diff --git a/.gitignore b/.gitignore index 5afe375f46f07b3b557ae23f75740b337517d3bd..1ef4c297ee4f369775c13b32a46a55887de719e7 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ __pycache__ *.swp .vscode/ cmake_build/ +tensorflow/contrib/cmake/_build/ .idea/** /build/ [Bb]uild/ @@ -30,6 +31,7 @@ Podfile.lock xcuserdata/** /api_init_files_list.txt /estimator_api_init_files_list.txt +*.whl # Android .gradle diff --git a/CODEOWNERS b/CODEOWNERS index b9f0313cc6d59d3fbdcd014e1a528126d863075a..94cc865479cd6ab5cdb589490d3a2d650f06b160 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,53 +1,67 @@ -# NOTE: Disabled temporarily because it's too noisy on pushes. # Where component owners are known, add them here. -# /tensorflow/core/platform/windows/ @mrry -# /tensorflow/java/ @asimshankar -# /tensorflow/tensorboard/ @jart @dandelionmane -# /tensorflow/tools/docs/ @markdaoust +/tenosrflow/core/debug @caisq +/tensorflow/core/platform/windows/ @mrry +/tensorflow/core/platform/s3 @yongtang +/tensorflow/go @asimshankar +/tensorflow/java/ @asimshankar +/tensorflow/python/debug @caisq +/tensorflow/python/tools/api/generator/ @annarev +/tensorflow/tensorboard/ @jart +/tensorflow/tools/docs/ @markdaoust # contrib -# NEED OWNER: /tensorflow/contrib/avro/ -# /tensorflow/contrib/batching/ @alextp @chrisolston -# /tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon -# /tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva -# /tensorflow/contrib/cmake/ @mrry @benoitsteiner -# /tensorflow/contrib/copy_graph/ @tucker @poxvoculi -# /tensorflow/contrib/crf/ @kentonl -# /tensorflow/contrib/data/ @mrry -# /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi -# /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo -# /tensorflow/contrib/ffmpeg/ @fredbertsch -# NEED OWNER: /tensorflow/contrib/framework/ -# /tensorflow/contrib/graph_editor/ @purpledog +# NEED OWNER: /tensorflow/contrib/all_reduce +/tensorflow/contrib/batching/ @alextp @chrisolston +/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon +/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva +/tensorflow/contrib/checkpoint/ @allenlavoie +/tensorflow/contrib/contrib/cluster_resolver/ @frankchn +/tensorflow/contrib/cmake/ @mrry +/tensorflow/contrib/copy_graph/ @tucker @poxvoculi +/tensorflow/contrib/crf/ @kentonl +/tensorflow/contrib/data/ @mrry +/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn +/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi +/tensorflow/contrib/eager @alextp @asimshankar +/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo +/tensorflow/contrib/ffmpeg/ @fredbertsch +/tensorflow/contrib/framework/ @ebrevdo +/tensorflow/contrib/gan/ @joel-shor +/tensorflow/contrib/graph_editor/ @purpledog # NEED OWNER: /tensorflow/contrib/grid_rnn/ -# /tensorflow/contrib/hvx/ @satok16 -# /tensorflow/contrib/integrate/ @shoyer -# /tensorflow/contrib/kernel_methods/ @petrosmol -# /tensorflow/contrib/ios_examples/ @petewarden -# /tensorflow/contrib/labeled_tensor/ @shoyer -# /tensorflow/contrib/layers/ @fchollet @martinwicke -# /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp -# /tensorflow/contrib/linalg/ @langmore -# /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis -# /tensorflow/contrib/lookup/ @ysuematsu @andreasst -# /tensorflow/contrib/losses/ @alextp @ispirmustafa -# /tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg -# /tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa -# /tensorflow/contrib/nccl/ @cwhipkey @zheng-xq -# /tensorflow/contrib/opt/ @strategist333 -# /tensorflow/contrib/pi_examples/ @maciekcc -# /tensorflow/contrib/quantization/ @petewarden @cwhipkey @keveman -# /tensorflow/contrib/rnn/ @ebrevdo -# /tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh -# /tensorflow/contrib/seq2seq/ @lukaszkaiser -# /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh -# /tensorflow/contrib/slim/ @sguada @thenbasilmanran -# /tensorflow/contrib/stateless/ @girving -# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -# /tensorflow/contrib/testing/ @dandelionmane -# /tensorflow/contrib/timeseries/ @allenlavoie -# /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu -# /tensorflow/contrib/training/ @joel-shor @ebrevdo -# /tensorflow/contrib/util/ @sherrym +/tensorflow/contrib/hadoop @yongtang +/tensorflow/contrib/hvx/ @satok16 +/tensorflow/contrib/integrate/ @shoyer +/tensorflow/contrib/kafka @yongtang +/tensorflow/contrib/kernel_methods/ @petrosmol +/tensorflow/contrib/kinesis @yongtang +/tensorflow/contrib/ios_examples/ @petewarden +/tensorflow/contrib/labeled_tensor/ @shoyer +/tensorflow/contrib/layers/ @fchollet @martinwicke +/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp +/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis +/tensorflow/contrib/lookup/ @ysuematsu @andreasst +/tensorflow/contrib/losses/ @alextp @ispirmustafa +/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg +/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa +/tensorflow/contrib/nccl/ @cwhipkey @zheng-xq +/tensorflow/contrib/opt/ @strategist333 @alextp +/tensorflow/contrib/pi_examples/ @maciekcc +/tensorflow/contrib/quantization/ @petewarden +/tensorflow/contrib/rnn/ @ebrevdo @scottzhu +/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl +/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang +/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh +/tensorflow/contrib/slim/ @sguada @thenbasilmanran +/tensorflow/contrib/stateless/ @girving @alextp +/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank +/tensorflow/contrib/tensorrt/ @aaroey +# NEED OWNER: /tensorflow/contrib/testing/ +/tensorflow/contrib/timeseries/ @allenlavoie +/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj +/tensorflow/contrib/training/ @joel-shor @ebrevdo +/tensorflow/contrib/util/ @sherrym + +/third_party/systemlibs/ @perfinion diff --git a/README.md b/README.md index 82de010dd445c57c3fcc566db53e18db025c1f9e..57efb876c9afaf9fe76c4ced4e6a1572e9241edf 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,14 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. +TensorFlow provides stable Python API and C APIs as well as without API backwards compatibility guarantee like C++, Go, Java, JavaScript and Swift. + Keep up to date with release announcements and security updates by subscribing to [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). ## Installation -*See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* +*See [Installing TensorFlow](https://www.tensorflow.org/install) for instructions on how to install our release binaries or how to build from source.* People who are a little more adventurous can also try our nightly binaries: @@ -46,15 +48,12 @@ $ python ``` ```python >>> import tensorflow as tf +>>> tf.enable_eager_execution() +>>> tf.add(1, 2) +3 >>> hello = tf.constant('Hello, TensorFlow!') ->>> sess = tf.Session() ->>> sess.run(hello) +>>> hello.numpy() 'Hello, TensorFlow!' ->>> a = tf.constant(10) ->>> b = tf.constant(32) ->>> sess.run(a + b) -42 ->>> sess.close() ``` Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). @@ -81,13 +80,15 @@ The TensorFlow project strives to abide by generally accepted best practices in | Build Type | Status | Artifacts | | --- | --- | --- | -| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Linux XLA** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.png) | TBA | -| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Android** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.png) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | +| **Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA | +| **MacOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | +| **Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) | +| **Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) | ### Community Supported Builds @@ -97,17 +98,20 @@ The TensorFlow project strives to abide by generally accepted best practices in | **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | | **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | -| **Linux CPU with Intel® MKL-DNN®** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA | +| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | +| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) | ## For more information - * [TensorFlow Website](https://www.tensorflow.org) -* [TensorFlow White Papers](https://www.tensorflow.org/about/bib) -* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) +* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) -* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) +* [TensorFlow Twitter](https://twitter.com/tensorflow) +* [TensorFlow Blog](https://medium.com/tensorflow) * [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) +* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) +* [TensorFlow White Papers](https://www.tensorflow.org/about/bib) +* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. diff --git a/RELEASE.md b/RELEASE.md index 078aafd3746e5ce5c16af15de80d99c1a9e8c567..20e1d9217b7684e696d0abf427eef9ab9548d1b7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,9 +1,92 @@ +# Release 1.11.0 + +## Major Features and Improvements + +* Nvidia GPU: + * Prebuilt binaries are now (as of TensorFlow 1.11) built against cuDNN 7.2 and TensorRT 4. See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) +* Google Cloud TPU: + * Experimental tf.data integration for Keras on Google Cloud TPUs. + * Experimental / preview support for eager execution on Google Cloud TPUs. +* DistributionStrategy: + * Add multi-GPU DistributionStrategy support in tf.keras. Users can now use `fit`, `evaluate` and `predict` to distribute their model on multiple GPUs. + * Add multi-worker DistributionStrategy and standalone client support in Estimator. See [README] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute) for more details. +* Add C, C++, and Python functions for querying kernels + +## Breaking Changes + +* Keras: + * The default values for tf.keras `RandomUniform`, `RandomNormal`, and `TruncatedNormal` initializers have been changed to match those in external Keras. + * Breaking change: `model.get_config()` on a Sequential model now returns a config dictionary (consistent with other Model instances) instead of a list of configs for the underlying layers. + +## Bug Fixes and Other Changes + +* C++: + * Changed the signature of SessionFactory::NewSession so that it can return a meaningful error message on failure. +* tf.data: + * Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. [tf.data] Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. + * `tf.data.Dataset.list_files()` raises an exception at initialization time if the argument matches no files. + * Renamed BigTable class to BigtableTable for clarity + * Document use of the Cloud Bigtable API + * Adding `tf.contrib.data.reduce_dataset` which can be used to reduce a dataset to a single element. + * Generalization of `tf.contrib.data.sliding_window_batch`. +* INC: + * Runtime improvements to triangular solve. +* `tf.contrib`: + * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` and `tf.keras.layers.LocallyConnected1D`. The new mode (`implementation=2`) performs forward pass as a single dense matrix multiplication, allowing dramatic speedups in certain scenarios (but worse performance in others - see docstring). The option also allows to use `padding=same`. + * Add documentation clarifying the differences between tf.fill and tf.constant. + * Add experimental IndexedDatasets. + * Add selective registration target using the lite proto runtime. + * Add simple Tensor and DataType classes to TensorFlow Lite Java + * Add support for bitcasting to/from uint32 and uint64. + * Added a subclass of Estimator that can be created from a SavedModel (SavedModelEstimator). + * Adds leaf index modes as an argument. + * Allow a different output shape from the input in tf.contrib.image.transform. + * Change the state_size order of the StackedRNNCell to be natural order. To keep the existing behavior, user can add reverse_state_order=True when constructing the StackedRNNCells. + * Deprecate self.test_session() in favor of self.session() or self.cached_session(). + * Directly import tensor.proto.h (the transitive import will be removed from tensor.h soon) + * Estimator.train() now supports tf.contrib.summary.\* summaries out of the box; each call to .train() will now create a separate tfevents file rather than re-using a shared one. + * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator. + * Fix toco compilation/execution on Windows + * GoogleZoneProvider class added to detect which Google Cloud Engine zone tensorflow is running in. + * It is now safe to call any of the C API's TF_Delete\* functions on nullptr + * Log some errors on Android to logcat + * Match FakeQuant numerics in TFLite to improve accuracy of TFLite quantized inference models. + * Optional bucket location check for the GCS Filesystem. + * Performance enhancements for StringSplitOp & StringSplitV2Op. + * Performance improvements for regex replace operations. + * TFRecordWriter now raises an error if .write() fails. + * TPU: More helpful error messages in TPUClusterResolvers. + * The legacy_init_op argument to SavedModelBuilder methods for adding MetaGraphs has been deprecated. Please use the equivalent main_op argument instead. As part of this, we now explicitly check for a single main_op or legacy_init_op at the time of SavedModel building, whereas the check on main_op was previously only done at load time. + * The protocol used for Estimator training is now configurable in RunConfig. + * Triangular solve performance improvements. + * Unify RNN cell interface between TF and Keras. Add new get_initial_state() to Keras and TF RNN cell, which will use to replace the existing zero_state() method. + * Update initialization of variables in Keras. + * Updates to "constrained_optimization" in tensorflow/contrib. + * boosted trees: adding pruning mode + * tf.train.Checkpoint does not delete old checkpoints by default. + * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Aapeli, adoda, Ag Ramesh, Amogh Mannekote, Andrew Gibiansky, Andy Craze, Anirudh Koul, Aurelien Geron, Avijit, Avijit-Nervana, Ben, Benjamin H. Myara, bhack, Brett Koonce, Cao Zongyan, cbockman, cheerss, Chikanaga Tomoyuki, Clayne Robison, cosine0, Cui Wei, Dan J, David, David Norman, Dmitry Klimenkov, Eliel Hojman, Florian Courtial, fo40225, formath, Geoffrey Irving, gracehoney, Grzegorz Pawelczak, Guoliang Hua, Guozhong Zhuang, Herman Zvonimir DošIlović, HuiyangFei, Jacker, Jan HüNnemeyer, Jason Taylor, Jason Zaman, Jesse, Jiang,Zhoulong, Jiawei Zhang, Jie, Joe Yearsley, Johannes Schmitz, Jon Perl, Jon Triebenbach, Jonathan, Jonathan Hseu, Jongmin Park, Justin Shenk, karl@kubx.ca, Kate Hodesdon, Kb Sriram, Keishi Hattori, Kenneth Blomqvist, Koan-Sin Tan, Li Liangbin, Li, Yiqiang, Loo Rong Jie, Madiyar, Mahmoud Abuzaina, Mark Ryan, Matt Dodge, mbhuiyan, melvinljy96, Miguel Mota, Nafis Sadat, Nathan Luehr, naurril, Nehal J Wani, Niall Moran, Niranjan Hasabnis, Nishidha Panpaliya, npow, olicht, Pei Zhang, Peng Wang (Simpeng), Peng Yu, Philipp Jund, Pradeep Banavara, Pratik Kalshetti, qwertWZ, Rakesh Chada, Randy West, Ray Kim, Rholais Lii, Robin Richtsfeld, Rodrigo Silveira, Ruizhi, Santosh Kumar, Seb Bro, Sergei Lebedev, sfujiwara, Shaba Abhiram, Shashi, SneakyFish5, Soila Kavulya, Stefan Dyulgerov, Steven Winston, Sunitha Kambhampati, Surry Shome, Taehoon Lee, Thor Johnsen, Tristan Rice, TShapinsky, tucan, tucan9389, Vicente Reyes, Vilmar-Hillow, Vitaly Lavrukhin, wangershi, weidan.kong, weidankong, Wen-Heng (Jack) Chung, William D. Irons, Wim Glenn, XFeiF, Yan Facai (颜发才), Yanbo Liang, Yong Tang, Yoshihiro Yamazaki, Yuan (Terry) Tang, Yuan, Man, zhaoyongke, ÁRon +Ricardo Perez-Lopez, 张天启, 张晓飞 + + +# Release 1.10.1 +## Bug Fixes and Other Changes + +* `tf.keras`: + * Fixing keras on Cloud TPUs. No new binaries will be built for Windows. + + # Release 1.10.0 ## Major Features And Improvements * The `tf.lite` runtime now supports `complex64`. -* Initial Bigtable integration for `tf.data`. +* Initial [Google Cloud Bigtable integration](https://github.com/tensorflow/tensorflow/tree/r1.10/tensorflow/contrib/bigtable) for `tf.data`. * Improved local run behavior in `tf.estimator.train_and_evaluate` which does not reload checkpoints for evaluation. * `RunConfig` now sets device_filters to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. But if you have jobs that require communication between workers, you will have to set custom session_options in your `RunConfig`. * Moved Distributions and Bijectors from `tf.contrib.distributions` to [Tensorflow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018. @@ -11,7 +94,7 @@ ## Breaking Changes -* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites). +* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [TensorFlow GPU support](https://www.tensorflow.org/install/gpu) and [Build TensorFlow from source](https://www.tensorflow.org/install/source). * Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake. ## Bug Fixes and Other Changes @@ -19,7 +102,7 @@ * `tf.data`: * `tf.contrib.data.group_by_reducer()` is now available via the public API. * `tf.contrib.data.choose_from_datasets()` is now available via the public API. - * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`. + * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating `tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`. * `tf.estimator`: * `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export. * `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases. diff --git a/configure.py b/configure.py index f97bf8a66836a6647ba6aca625cb1526e11b39af..3fcaaa9d0ef51c57fb40fcafd8579977f37375ef 100644 --- a/configure.py +++ b/configure.py @@ -41,11 +41,10 @@ _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) -_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' -_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 @@ -54,6 +53,11 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) _TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') +if platform.machine() == 'ppc64le': + _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/' +else: + _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() + class UserInputError(Exception): pass @@ -153,14 +157,18 @@ def get_python_path(environ_cp, python_bin_path): if environ_cp.get('PYTHONPATH'): python_paths = environ_cp.get('PYTHONPATH').split(':') try: - library_paths = run_shell( - [python_bin_path, '-c', - 'import site; print("\\n".join(site.getsitepackages()))']).split('\n') + library_paths = run_shell([ + python_bin_path, '-c', + 'import site; print("\\n".join(site.getsitepackages()))' + ]).split('\n') except subprocess.CalledProcessError: - library_paths = [run_shell( - [python_bin_path, '-c', - 'from distutils.sysconfig import get_python_lib;' - 'print(get_python_lib())'])] + library_paths = [ + run_shell([ + python_bin_path, '-c', + 'from distutils.sysconfig import get_python_lib;' + 'print(get_python_lib())' + ]) + ] all_paths = set(python_paths + library_paths) @@ -187,8 +195,7 @@ def setup_python(environ_cp): environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path, default_python_bin_path) # Check if the path is valid - if os.path.isfile(python_bin_path) and os.access( - python_bin_path, os.X_OK): + if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): break elif not os.path.exists(python_bin_path): print('Invalid python path: %s cannot be found.' % python_bin_path) @@ -230,8 +237,9 @@ def setup_python(environ_cp): environ_cp['PYTHON_BIN_PATH'] = python_bin_path # Write tools/python_bin_path.sh - with open(os.path.join( - _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: + with open( + os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), + 'w') as f: f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) @@ -250,7 +258,7 @@ def reset_tf_configure_bazelrc(workspace_path): continue f.write('%s\n' % l) if is_windows(): - tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/") + tf_bazelrc_path = _TF_BAZELRC.replace('\\', '/') else: tf_bazelrc_path = _TF_BAZELRC f.write('import %s\n' % tf_bazelrc_path) @@ -261,8 +269,8 @@ def cleanup_makefile(): These files could interfere with Bazel parsing. """ - makefile_download_dir = os.path.join( - _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads') + makefile_download_dir = os.path.join(_TF_WORKSPACE_ROOT, 'tensorflow', + 'contrib', 'makefile', 'downloads') if os.path.isdir(makefile_download_dir): for root, _, filenames in os.walk(makefile_download_dir): for f in filenames: @@ -330,9 +338,8 @@ def get_var(environ_cp, 'Environment variable %s must be set as a boolean indicator.\n' 'The following are accepted as TRUE : %s.\n' 'The following are accepted as FALSE: %s.\n' - 'Current value is %s.' % ( - var_name, ', '.join(true_strings), ', '.join(false_strings), - var)) + 'Current value is %s.' % (var_name, ', '.join(true_strings), + ', '.join(false_strings), var)) while var is None: user_input_origin = get_input(question) @@ -355,8 +362,12 @@ def get_var(environ_cp, return var -def set_build_var(environ_cp, var_name, query_item, option_name, - enabled_by_default, bazel_config_name=None): +def set_build_var(environ_cp, + var_name, + query_item, + option_name, + enabled_by_default, + bazel_config_name=None): """Set if query_item will be enabled for the build. Ask user if query_item will be enabled. Default is used if no input is given. @@ -379,8 +390,8 @@ def set_build_var(environ_cp, var_name, query_item, option_name, elif bazel_config_name is not None: # TODO(mikecase): Migrate all users of configure.py to use --config Bazel # options and not to set build configs through environment variables. - write_to_bazelrc('build:%s --define %s=true' - % (bazel_config_name, option_name)) + write_to_bazelrc( + 'build:%s --define %s=true' % (bazel_config_name, option_name)) def set_action_env_var(environ_cp, @@ -447,7 +458,8 @@ def check_bazel_version(min_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version']) + curr_version = run_shell( + ['bazel', '--batch', '--bazelrc=/dev/null', 'version']) for line in curr_version.split('\n'): if 'Build label: ' in line: @@ -499,6 +511,7 @@ def set_cc_opt_flags(environ_cp): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') + def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -581,16 +594,14 @@ def set_clang_cuda_compiler_path(environ_cp): clang_cuda_compiler_path) -def prompt_loop_or_load_from_env( - environ_cp, - var_name, - var_default, - ask_for_var, - check_success, - error_msg, - suppress_default_error=False, - n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS -): +def prompt_loop_or_load_from_env(environ_cp, + var_name, + var_default, + ask_for_var, + check_success, + error_msg, + suppress_default_error=False, + n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS): """Loop over user prompts for an ENV param until receiving a valid response. For the env param var_name, read from the environment or verify user input @@ -629,9 +640,7 @@ def prompt_loop_or_load_from_env( ) for _ in range(n_ask_attempts): - val = get_from_env_or_user_or_default(environ_cp, - var_name, - full_query, + val = get_from_env_or_user_or_default(environ_cp, var_name, full_query, default) if check_success(val): break @@ -639,9 +648,9 @@ def prompt_loop_or_load_from_env( print(error_msg % val) environ_cp[var_name] = '' else: - raise UserInputError('Invalid %s setting was provided %d times in a row. ' - 'Assuming to be a scripting mistake.' % - (var_name, n_ask_attempts)) + raise UserInputError( + 'Invalid %s setting was provided %d times in a row. ' + 'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts)) environ_cp[var_name] = val return val @@ -650,8 +659,8 @@ def prompt_loop_or_load_from_env( def create_android_ndk_rule(environ_cp): """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" if is_windows() or is_cygwin(): - default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % - environ_cp['APPDATA']) + default_ndk_path = cygpath( + '%s/Android/Sdk/ndk-bundle' % environ_cp['APPDATA']) elif is_macos(): default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] else: @@ -668,8 +677,7 @@ def create_android_ndk_rule(environ_cp): ask_for_var='Please specify the home path of the Android NDK to use.', check_success=valid_ndk_path, error_msg=('The path %s or its child file "source.properties" ' - 'does not exist.') - ) + 'does not exist.')) write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', check_ndk_level(android_ndk_home_path)) @@ -703,9 +711,9 @@ def create_android_sdk_rule(environ_cp): api_levels = [x.replace('android-', '') for x in api_levels] def valid_api_level(api_level): - return os.path.exists(os.path.join(android_sdk_home_path, - 'platforms', - 'android-' + api_level)) + return os.path.exists( + os.path.join(android_sdk_home_path, 'platforms', + 'android-' + api_level)) android_api_level = prompt_loop_or_load_from_env( environ_cp, @@ -720,9 +728,8 @@ def create_android_sdk_rule(environ_cp): versions = sorted(os.listdir(build_tools)) def valid_build_tools(version): - return os.path.exists(os.path.join(android_sdk_home_path, - 'build-tools', - version)) + return os.path.exists( + os.path.join(android_sdk_home_path, 'build-tools', version)) android_build_tools_version = prompt_loop_or_load_from_env( environ_cp, @@ -736,10 +743,8 @@ def create_android_sdk_rule(environ_cp): write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', android_build_tools_version) - write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', - android_api_level) - write_action_env_to_bazelrc('ANDROID_SDK_HOME', - android_sdk_home_path) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path) def check_ndk_level(android_ndk_home_path): @@ -798,6 +803,7 @@ def reformat_version_sequence(version_str, sequence_count): Args: version_str: String, the version string. sequence_count: int, an integer. + Returns: string, reformatted version string. """ @@ -839,19 +845,27 @@ def set_tf_cuda_version(environ_cp): cuda_toolkit_path = cygpath(cuda_toolkit_path) if is_windows(): - cuda_rt_lib_path = 'lib/x64/cudart.lib' + cuda_rt_lib_paths = ['lib/x64/cudart.lib'] elif is_linux(): - cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version + cuda_rt_lib_paths = [ + '%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [ + 'lib64', + 'lib/powerpc64le-linux-gnu', + 'lib/x86_64-linux-gnu', + ] + ] elif is_macos(): - cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version + cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version] - cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path) - if os.path.exists(cuda_toolkit_path_full): + cuda_toolkit_paths_full = [ + os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths + ] + if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): break # Reset and retry print('Invalid path to CUDA %s toolkit. %s cannot be found' % - (tf_cuda_version, cuda_toolkit_path_full)) + (tf_cuda_version, cuda_toolkit_paths_full)) environ_cp['TF_CUDA_VERSION'] = '' environ_cp['CUDA_TOOLKIT_PATH'] = '' @@ -918,8 +932,8 @@ def set_tf_cudnn_version(environ_cp): cudnn_path_from_ldconfig) if cudnn_path_from_ldconfig: cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1) - if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, - tf_cudnn_version)): + if os.path.exists( + '%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)): cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig) break @@ -1165,6 +1179,7 @@ def get_native_cuda_compute_capabilities(environ_cp): Args: environ_cp: copy of the os.environ. + Returns: string of native cuda compute capabilities, separated by comma. """ @@ -1289,8 +1304,7 @@ def set_computecpp_toolkit_path(environ_cp): else: sycl_rt_lib_path = '' - sycl_rt_lib_path_full = os.path.join(toolkit_path, - sycl_rt_lib_path) + sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path) exists = os.path.exists(sycl_rt_lib_path_full) if not exists: print('Invalid SYCL %s library path. %s cannot be found' % @@ -1318,8 +1332,8 @@ def set_trisycl_include_dir(environ_cp): ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' 'include directory. (Use --config=sycl_trisycl ' 'when building with Bazel) ' - '[Default is %s]: ' - ) % (_DEFAULT_TRISYCL_INCLUDE_DIR) + '[Default is %s]: ') % ( + _DEFAULT_TRISYCL_INCLUDE_DIR) while True: trisycl_include_dir = get_from_env_or_user_or_default( @@ -1328,13 +1342,12 @@ def set_trisycl_include_dir(environ_cp): if os.path.exists(trisycl_include_dir): break - print('Invalid triSYCL include directory, %s cannot be found' - % (trisycl_include_dir)) + print('Invalid triSYCL include directory, %s cannot be found' % + (trisycl_include_dir)) # Set TRISYCL_INCLUDE_DIR environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir - write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', - trisycl_include_dir) + write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) def set_mpi_home(environ_cp): @@ -1344,8 +1357,9 @@ def set_mpi_home(environ_cp): default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) def valid_mpi_path(mpi_home): - exists = (os.path.exists(os.path.join(mpi_home, 'include')) and - os.path.exists(os.path.join(mpi_home, 'lib'))) + exists = ( + os.path.exists(os.path.join(mpi_home, 'include')) and + os.path.exists(os.path.join(mpi_home, 'lib'))) if not exists: print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % (os.path.join(mpi_home, 'include'), @@ -1394,12 +1408,21 @@ def set_other_mpi_vars(environ_cp): raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home) -def set_grpc_build_flags(): - write_to_bazelrc('build --define grpc_no_ares=true') - +def set_system_libs_flag(environ_cp): + syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') + if syslibs and syslibs != '': + if ',' in syslibs: + syslibs = ','.join(sorted(syslibs.split(','))) + else: + syslibs = ','.join(sorted(syslibs.split())) + write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) -def set_build_strip_flag(): - write_to_bazelrc('build --strip=always') + if 'PREFIX' in environ_cp: + write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX']) + if 'LIBDIR' in environ_cp: + write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR']) + if 'INCLUDEDIR' in environ_cp: + write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR']) def set_windows_build_flags(environ_cp): @@ -1417,14 +1440,20 @@ def set_windows_build_flags(environ_cp): # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0 # Short object file path will be enabled by default. write_to_bazelrc('build --experimental_shortened_obj_file_path=true') + # When building zip file for some py_binary and py_test targets, don't + # include its dependencies. This is for: + # 1. Running python tests against the system installed TF pip package. + # 2. Avoiding redundant files in + # //tensorflow/tools/pip_package:simple_console_windows, + # which is a py_binary used during creating TF pip package. + # See https://github.com/tensorflow/tensorflow/issues/22390 + write_to_bazelrc('build --define=no_tensorflow_py_deps=true') if get_var( environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', - True, - ('Would you like to override eigen strong inline for some C++ ' - 'compilation to reduce the compilation time?'), - 'Eigen strong inline overridden.', - 'Not overriding eigen strong inline, ' + True, ('Would you like to override eigen strong inline for some C++ ' + 'compilation to reduce the compilation time?'), + 'Eigen strong inline overridden.', 'Not overriding eigen strong inline, ' 'some compilations could take more than 20 mins.'): # Due to a known MSVC compiler issue # https://github.com/tensorflow/tensorflow/issues/10521 @@ -1441,10 +1470,11 @@ def config_info_line(name, help_text): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--workspace", - type=str, - default=_TF_WORKSPACE_ROOT, - help="The absolute path to your active Bazel workspace.") + parser.add_argument( + '--workspace', + type=str, + default=_TF_WORKSPACE_ROOT, + help='The absolute path to your active Bazel workspace.') args = parser.parse_args() # Make a copy of os.environ to be clear when functions and getting and setting @@ -1472,8 +1502,6 @@ def main(): # Windows. environ_cp['TF_DOWNLOAD_CLANG'] = '0' environ_cp['TF_ENABLE_XLA'] = '0' - environ_cp['TF_NEED_GDR'] = '0' - environ_cp['TF_NEED_VERBS'] = '0' environ_cp['TF_NEED_MPI'] = '0' environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' @@ -1486,7 +1514,7 @@ def main(): # runtime to allow the Tensorflow testcases which compare numpy # results to Tensorflow results to succeed. if is_ppc64le(): - write_action_env_to_bazelrc("OMP_NUM_THREADS", 1) + write_action_env_to_bazelrc('OMP_NUM_THREADS', 1) set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 'with_jemalloc', True) @@ -1500,10 +1528,7 @@ def main(): 'with_kafka_support', True, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') - set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', - False, 'gdr') - set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support', - False, 'verbs') + set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False) if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': @@ -1537,6 +1562,10 @@ def main(): if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': # Set up which clang we should use as the cuda / host compiler. set_clang_cuda_compiler_path(environ_cp) + else: + # Use downloaded LLD for linking. + write_to_bazelrc('build:cuda_clang --config=download_clang_use_lld') + write_to_bazelrc('test:cuda_clang --config=download_clang_use_lld') else: # Set up which gcc nvcc should use as the host compiler # No need to set this on Windows @@ -1556,19 +1585,18 @@ def main(): set_mpi_home(environ_cp) set_other_mpi_vars(environ_cp) - set_grpc_build_flags() set_cc_opt_flags(environ_cp) - set_build_strip_flag() + set_system_libs_flag(environ_cp) if is_windows(): set_windows_build_flags(environ_cp) - if get_var( - environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', - False, - ('Would you like to interactively configure ./WORKSPACE for ' - 'Android builds?'), - 'Searching for NDK and SDK installations.', - 'Not configuring the WORKSPACE for Android builds.'): + # Add a config option to build TensorFlow 2.0 API. + write_to_bazelrc('build:v2 --define=tf_api_version=2') + + if get_var(environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) @@ -1581,6 +1609,11 @@ def main(): 'more details.') config_info_line('mkl', 'Build with MKL support.') config_info_line('monolithic', 'Config for mostly static monolithic build.') + config_info_line('gdr', 'Build with GDR support.') + config_info_line('verbs', 'Build with libverbs support.') + config_info_line('ngraph', 'Build with Intel ngraph support.') + if __name__ == '__main__': main() + diff --git a/tensorflow/BUILD b/tensorflow/BUILD index e13a5cf802ece5fd53c1ca2db931a548aa7fe451..3610eea42a58ab74940e059736dd692713d001f1 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -12,6 +12,7 @@ exports_files([ # The leakr files are used by //third_party/cloud_tpu. "leakr_badwords.dic", "leakr_badfiles.dic", + "leakr_file_type_recipe.ftrcp", ]) load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") @@ -23,6 +24,24 @@ load( "//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files", # @unused ) +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files") +load( + "//tensorflow/python/tools/api/generator:api_init_files.bzl", + "TENSORFLOW_API_INIT_FILES", # @unused +) +load( + "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", + "TENSORFLOW_API_INIT_FILES_V1", # @unused +) +load( + "//third_party/ngraph:build_defs.bzl", + "if_ngraph", +) + +# @unused +TENSORFLOW_API_INIT_FILES_V2 = ( + TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) # Config setting used when building for products # which requires restricted licenses to be avoided. @@ -123,12 +142,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, - visibility = ["//visibility:public"], -) - config_setting( name = "no_tensorflow_py_deps", define_values = {"no_tensorflow_py_deps": "true"}, @@ -387,6 +400,7 @@ config_setting( define_values = { "dynamic_loaded_kernels": "true", }, + visibility = ["//visibility:public"], ) config_setting( @@ -416,12 +430,28 @@ config_setting( visibility = ["//visibility:public"], ) +# This flag is set from the configure step when the user selects with nGraph option. +# By default it should be false +config_setting( + name = "with_ngraph_support", + values = {"define": "with_ngraph_support=true"}, + visibility = ["//visibility:public"], +) + +# This flag specifies whether TensorFlow 2.0 API should be built instead +# of 1.* API. Note that TensorFlow 2.0 API is currently under development. +config_setting( + name = "api_version_2", + define_values = {"tf_api_version": "2"}, +) + package_group( name = "internal", packages = [ "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", + "//tensorflow_estimator/...", "//tensorflow_fold/llgtm/...", "//third_party/py/tensor2tensor/...", ], @@ -429,12 +459,12 @@ package_group( load( "//third_party/mkl:build_defs.bzl", - "if_mkl", + "if_mkl_ml", ) filegroup( name = "intel_binary_blob", - data = if_mkl( + data = if_mkl_ml( [ "//third_party/mkl:intel_binary_blob", ], @@ -487,7 +517,6 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_framework_version_script.lds)", @@ -529,13 +558,13 @@ tf_cc_shared_object( "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow/c:version_script.lds)", ], }), + visibility = ["//visibility:public"], deps = [ "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", @@ -554,13 +583,13 @@ tf_cc_shared_object( "$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_version_script.lds)", ], }), + visibility = ["//visibility:public"], deps = [ "//tensorflow:tf_exported_symbols.lds", "//tensorflow:tf_version_script.lds", @@ -571,7 +600,7 @@ tf_cc_shared_object( "//tensorflow/cc:scope", "//tensorflow/cc/profiler", "//tensorflow/core:tensorflow", - ], + ] + if_ngraph(["@ngraph_tf//:ngraph_tf"]), ) exports_files( @@ -581,9 +610,73 @@ exports_files( ], ) +genrule( + name = "install_headers", + srcs = [ + "//tensorflow/c:headers", + "//tensorflow/c/eager:headers", + "//tensorflow/cc:headers", + "//tensorflow/core:headers", + ], + outs = ["include"], + cmd = """ + mkdir $@ + for f in $(SRCS); do + d="$${f%/*}" + d="$${d#bazel-out*genfiles/}" + d="$${d#*external/eigen_archive/}" + + if [[ $${d} == *local_config_* ]]; then + continue + fi + + if [[ $${d} == external* ]]; then + extname="$${d#*external/}" + extname="$${extname%%/*}" + if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then + continue + fi + fi + + mkdir -p "$@/$${d}" + cp "$${f}" "$@/$${d}/" + done + """, + tags = ["manual"], + visibility = ["//visibility:public"], +) + +genrule( + name = "root_init_gen", + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }), + outs = ["__init__.py"], + cmd = select({ + "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + }), +) + +gen_api_init_files( + name = "tf_python_api_gen_v1", + srcs = ["api_template.__init__.py"], + api_version = 1, + output_dir = "_api/v1/", + output_files = TENSORFLOW_API_INIT_FILES_V1, + output_package = "tensorflow._api.v1", + root_init_template = "api_template.__init__.py", +) + gen_api_init_files( - name = "tensorflow_python_api_gen", + name = "tf_python_api_gen_v2", srcs = ["api_template.__init__.py"], + api_version = 2, + compat_api_versions = [1], + output_dir = "_api/v2/", + output_files = TENSORFLOW_API_INIT_FILES_V2, + output_package = "tensorflow._api.v2", root_init_template = "api_template.__init__.py", ) @@ -601,7 +694,10 @@ py_library( py_library( name = "tensorflow_py_no_contrib", - srcs = [":tensorflow_python_api_gen"], + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }) + [":root_init_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index 440e9f8dbd2f4b2a2ab78eaaf26408584e7c1446..21677512b63828fa2035527ed573bf4dc4603085 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -28,7 +28,8 @@ contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top -app.flags = flags # pylint: disable=undefined-variable +from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top +app.flags = flags del absolute_import del division diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 779f65d5b17c350833f67f07985b00e8eb561e72..2de740e145f93b151faf5c987808dbdf73fb4fd7 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -14,15 +14,16 @@ # ============================================================================== """Bring in all of the public TensorFlow interface into this module.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import as _absolute_import +from __future__ import division as _division +from __future__ import print_function as _print_function + +import os as _os # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import try: - import os # pylint: disable=g-import-not-at-top # Add `estimator` attribute to allow access to estimator APIs via # "tf.estimator..." from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top @@ -30,9 +31,8 @@ try: # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." # style imports. from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top - __path__ += [os.path.dirname(estimator_api.__file__)] + __path__ += [_os.path.dirname(estimator_api.__file__)] del estimator_api - del os except (ImportError, AttributeError): print('tf.estimator package not installed.') @@ -41,19 +41,32 @@ except (ImportError, AttributeError): from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader +# The templated code that replaces the placeholder above sometimes +# sets the __all__ variable. If it does, we have to be sure to add +# "contrib". +if '__all__' in vars(): + vars()['__all__'].append('contrib') from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable -del absolute_import -del division -del print_function +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +if _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They # must come from this module. So python adds these symbols for the # resolution to succeed. # pylint: disable=undefined-variable -del python -del core +try: + del python + del core +except NameError: + # Don't fail if these modules are not available. + # For e.g. we are using this file for compat.v1 module as well and + # 'python', 'core' directories are not under compat/v1. + pass # pylint: enable=undefined-variable diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 8a9301d584775cff3ae315e6fd856b00d1734248..17e2e292eb19029d279bc12a8328edadf96f1bb8 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -117,6 +117,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + "//tensorflow/c/eager:c_api", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", @@ -127,6 +128,15 @@ tf_cuda_library( ], ) +cc_library( + name = "c_api_headers", + hdrs = [ + "c_api.h", + ], + copts = tf_copts(), + visibility = ["//tensorflow:__subpackages__"], +) + exports_files( [ "version_script.lds", @@ -194,6 +204,7 @@ tf_cuda_cc_test( "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], }), + tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), @@ -235,6 +246,7 @@ tf_cc_test( ":c_api_experimental", ":c_test_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 19ccb6e71d2f3021c1ce5c8905d8a72059c1cfcb..79811ceae57e0bddeb2a6f32bad7003e14e23422 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" @@ -202,7 +203,8 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, buf->len_ = len; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) && - reinterpret_cast(data) % EIGEN_MAX_ALIGN_BYTES != 0) { + reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) != + 0) { // TF_STRING and TF_RESOURCE tensors have a different representation in // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste // (any alignment requirements will be taken care of by TF_TensorToTensor @@ -1239,7 +1241,7 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, const char* value, size_t length) { tensorflow::NameAttrList func_name; - func_name.set_name(std::string(value, value + length)); + func_name.set_name(string(value, value + length)); desc->node_builder.Attr(attr_name, func_name); } @@ -2064,7 +2066,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, for (int i = 0; i < size; ++i) { TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.push_back(std::string(id.first)); + tf_results->missing_unused_key_names_data.emplace_back(id.first); tf_results->missing_unused_key_names[i] = tf_results->missing_unused_key_names_data.back().c_str(); tf_results->missing_unused_key_indexes[i] = id.second; diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 69b3ffe2a1f620e346405607ecf742fb863aa644..f316e4ba6735213ba2fbbc1f8c019ad235c0df1f 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -17,11 +17,13 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" using tensorflow::FunctionDef; using tensorflow::Node; @@ -79,6 +81,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, auto* gpu_options = config.mutable_gpu_options(); gpu_options->set_allow_growth(gpu_memory_allow_growth); + // TODO(b/113217601): This is needed for EagerContext::runner_ to use a + // threadpool, so that we avoid the possibility of running the runner_ in the + // threadpool of GPU event mgr, as that can trigger more callbacks to be + // scheduled on that same threadpool, causing a deadlock in cases where the + // caller of event_mgr->ThenExecute() blocks on the completion of the callback + // (as in the case of ConstOp kernel creation on GPU, which involves copying a + // CPU tensor to GPU). + // Setting a larger thread pool does not help with the Swift caller, as we use + // a different TFE context for each thread of execution (for running graph + // functions, and their send/recvs corountines). + config.set_inter_op_parallelism_threads(1); + TF_Buffer* ret = TF_NewBuffer(); TF_CHECK_OK(MessageToBuffer(config, ret)); return ret; @@ -8494,3 +8508,265 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, /*run_metadata*/ nullptr, status); VLOG(1) << "Enqueuing is done."; } + +TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) { + tensorflow::ServerDef server_def; + if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, + &server_def)) { + status->status = tensorflow::errors::Internal( + "Invalid text proto for ServerDef: ", text_proto); + return nullptr; + } + status->status = tensorflow::Status(); + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(server_def, ret)); + return ret; +} + +TFE_Context* TFE_CreateContextFromSession(TF_Session* session, + TF_Status* status) { + auto* opts = TFE_NewContextOptions(); + + // Reduce GPU memory allocation, and set appropriate config options for TFE + // context. + auto* config = + TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true); + TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); + if (!status->status.ok()) { + CHECK(!config); + TFE_DeleteContextOptions(opts); + return nullptr; + } + + auto* ctx = TFE_NewContextFromSession(opts, session, status); + TF_DeleteBuffer(config); + TFE_DeleteContextOptions(opts); + return ctx; +} + +// TODO: retrieve the device string via TFE_ContextListDevices() +static const char DEFAULT_CPU_DEVICE[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + +static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType, + int tensor_id, TF_Status* status) { + std::unique_ptr queueOp( + TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp); + TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler. + TFE_OpSetAttrInt(queueOp.get(), "capacity", 1); + TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1); + auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id); + TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(), + shared_name.size()); + TFE_OpSetAttrString(queueOp.get(), "container", "", 0); + + // TODO: consider making this an unknown shape. + const int64_t* dims_ptr = nullptr; + int num_dims = 0; + TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims, + /*num_values*/ 0, status); + if (!status->status.ok()) return nullptr; + + int num_retvals = 1; + TFE_TensorHandle* queue = nullptr; + TFE_Execute(queueOp.get(), &queue, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + + return queue; +} + +static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType, + TFE_TensorHandle* queue, TFE_TensorHandle* tensor, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status); + if (!status->status.ok()) return; + std::unique_ptr op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, tensor, status); + if (!status->status.ok()) return; + TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + + int num_retvals = 0; + TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status); + if (!status->status.ok()) return; + CHECK_EQ(num_retvals, 0); +} + +static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx, + TF_DataType inputType, + TFE_TensorHandle* queue, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status); + if (!status->status.ok()) return nullptr; + std::unique_ptr op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + + TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return nullptr; + TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + TFE_TensorHandle* ret; + int num_retvals = 1; + TFE_Execute(op, &ret, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + return ret; +} + +TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + assert(session); + VLOG(1) << "Dequeuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + return ret; +} + +TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + + return ret; +} + +void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + assert(session); + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status) { + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + VLOG(1) << "Enqueuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status); +} + +TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, + TF_Status* status) { + VLOG(1) << "Dequeuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + return createTFEDequeue(ctx, TF_VARIANT, queue, status); +} + +static void CheckOk(TF_Status* status) { + CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); +} + +void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { + auto* status = TF_NewStatus(); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::Tensor dst; + TF_CHECK_OK(TF_TensorToTensor(t, &dst)); + LOG(INFO) << dst.DebugString(); + + TF_DeleteTensor(t); + TF_DeleteStatus(status); +} + +TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx) { + // Intentionally LOG into INFO below for ease of debugging. + VLOG(1) << "TFE_RunConstOp called"; + + auto* status = TF_NewStatus(); + auto* op = TFE_NewOp(ctx, "Const", status); + CheckOk(status); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + + auto* tensor = + TF_AllocateTensor(TF_FLOAT, /*shape.data()*/ nullptr, /*shape.size()*/ 0, + TF_DataTypeSize(TF_FLOAT) * 1); + auto* ptr = reinterpret_cast(TF_TensorData(tensor)); + *reinterpret_cast(ptr) = 17.0; + + TFE_OpSetAttrTensor(op, "value", tensor, status); + CheckOk(status); + TF_DeleteTensor(tensor); + VLOG(1) << "New op created"; + + TFE_TensorHandle* retval; + int num_retvals = 1; + TFE_Execute(op, &retval, &num_retvals, status); + CheckOk(status); + CHECK_EQ(num_retvals, 1); + VLOG(1) << "Op executed"; + + TFE_DeleteOp(op); + TF_DeleteStatus(status); + + return retval; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 6617c5a572e90e78369f73d714f39942f213040f..950ad9aeed6f883fa22c2673fa8aa92839cd0fbc 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" // -------------------------------------------------------------------------- // Experimental C API for TensorFlow. @@ -130,6 +131,59 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, TF_Tensor* tensor, TF_Status* status); +// Create a serialized tensorflow.ServerDef proto. +TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status); + +// TODO: remove this API in favor of the next one. +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( + const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); + +// Creates from `session` a new eager context to run a graph function or +// sends/recvs, so that these concurrent TFE executions can share (via +// `session` and its associated device mgr) the same set of fifo queue resource +// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and +// graph function execution can access the same fifo queue resource handles +// (associated with devices managed by the device manager, which can be obtained +// from `session`). +// +// TODO: Remove this function once we migrate away from using session. +TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession( + TF_Session* session, TF_Status* status); + +// TODO: Retire this API in favor of the next one. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor( + TF_Session* session, int tensor_id, TF_DataType inputType, + TF_Status* status); + +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx( + TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session, + int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx( + TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor, + TF_Status* status); + +// TODO: consider folding the 2 APIs below into the ones above. +TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session, + int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status); + +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( + TF_Session* session, int tensor_id, TF_Status* status); + +// Prints `handle` in a human readable format to standard output for debugging. +TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( + TFE_TensorHandle* handle); + +// Returns a const scalar tensor. +// Caller owns both the input and the output tensor handles. +// TODO: Remove this API with hard-coded tensor computation. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx); #ifdef __cplusplus } /* end extern "C" */ diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 30fcfd401d9d634962d64aaa3bf348de91f2ecae..c6effd39697e0397278770b53e98508074f99862 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" namespace tensorflow { namespace { @@ -116,5 +118,49 @@ TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { TF_DeleteStatus(s); } +TEST(CAPI_EXPERIMENTAL, GetServerDefTest) { + const string expected_text_proto(R"(cluster { + job { + name: "worker" + tasks { + key: 0 + value: "tpuserver:0" + } + tasks { + key: 1 + value: "localhost:1" + } + } +} +job_name: "worker" +task_index: 1 +protocol: "grpc" +)"); + + TF_Status* status = TF_NewStatus(); + TF_Buffer* result = TFE_GetServerDef(expected_text_proto.c_str(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK); + + ServerDef actual; + ASSERT_TRUE(actual.ParseFromArray(result->data, result->length)); + string actual_text_proto; + tensorflow::protobuf::TextFormat::PrintToString(actual, &actual_text_proto); + EXPECT_EQ(expected_text_proto, actual_text_proto); + + const string malformed_text_proto(R"(cluster { + job { + name: "worker")"); + TF_Buffer* null_result = + TFE_GetServerDef(malformed_text_proto.c_str(), status); + EXPECT_NE(TF_GetCode(status), TF_OK); + EXPECT_TRUE(tensorflow::str_util::StrContains( + TF_Message(status), "Invalid text proto for ServerDef")); + EXPECT_EQ(null_result, nullptr); + + // Cleanup + TF_DeleteBuffer(result); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index a2c5a42c11361779de61b515e0f08dcc45e609b9..f68f8a3e90a971b5e4a024feaf26ba498afc48da 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/strings/base64.h" diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index bb9433ce25e0e3b9cfb54698c940cc1b38c88d31..73fe73769bc1219ce865149d67d333c53371ccc5 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1619,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } +// This test only works when the TF build includes XLA compiler. One way to set +// this up is via bazel build option "--define with_xla_support=true". +// +// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to +// something like TENSORFLOW_CAPI_USE_XLA. +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST_F(CApiFunctionTest, StatelessIf_XLA) { + TF_Function* func; + const std::string funcName = "BranchFunc"; + DefineFunction(funcName.c_str(), &func); + TF_GraphCopyFunction(host_graph_, func, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* feed = Placeholder(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_OperationDescription* desc = + TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); + TF_AddInput(desc, {true_cond, 0}); + TF_Output inputs[] = {{feed, 0}}; + TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_SetAttrType(desc, "Tcond", TF_BOOL); + TF_DataType inputType = TF_INT32; + TF_SetAttrTypeList(desc, "Tin", &inputType, 1); + TF_SetAttrTypeList(desc, "Tout", &inputType, 1); + TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); + TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); + TF_SetDevice(desc, "/device:XLA_CPU:0"); + auto op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(op, nullptr); + + // Create a session for this graph. + CSession csession(host_graph_, s_, /*use_XLA*/ true); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run the graph. + csession.SetInputs({{feed, Int32Tensor(17)}}); + csession.SetOutputs({op}); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* output_contents = static_cast(TF_TensorData(out)); + EXPECT_EQ(-17, *output_contents); + + // Clean up + csession.CloseAndDelete(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_DeleteFunction(func); +} +#endif // TENSORFLOW_EAGER_USE_XLA + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index aa2a537f03be31ae45ff3d6f7815b449d661cf9c..03516c39dc970aa23967107d3a0446da94669465 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -259,8 +259,8 @@ TEST(CAPI, DeprecatedSession) { TF_Run(session, run_options, nullptr, nullptr, 0, nullptr, nullptr, 0, nullptr, 0, run_metadata, s); EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(std::string("Session was not created with a graph before Run()!"), - std::string(TF_Message(s))); + EXPECT_EQ("Session was not created with a graph before Run()!", + string(TF_Message(s))); TF_DeleteBuffer(run_metadata); TF_DeleteBuffer(run_options); @@ -1224,8 +1224,8 @@ class CApiColocationTest : public ::testing::Test { TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_); if (expected.empty()) { ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); - EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."), - std::string(TF_Message(s_))); + EXPECT_EQ("Operation 'add' has no attr named '_class'.", + string(TF_Message(s_))); return; } EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); @@ -1369,16 +1369,16 @@ TEST(CAPI, SavedModel) { input.flat()(i) = example.SerializeAsString(); } - const tensorflow::string input_op_name = - std::string(tensorflow::ParseTensorName(input_name).first); + const tensorflow::string input_op_name( + tensorflow::ParseTensorName(input_name).first); TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}}); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - const tensorflow::string output_op_name = - std::string(tensorflow::ParseTensorName(output_name).first); + const tensorflow::string output_op_name( + tensorflow::ParseTensorName(output_name).first); TF_Operation* output_op = TF_GraphOperationByName(graph, output_op_name.c_str()); ASSERT_TRUE(output_op != nullptr); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 24eb6c069b21349fce288db3e79fbf14e824ad11..f15d9ee20adb31a0b76e2cd0d1e67f17a9deff05 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -26,6 +26,10 @@ limitations under the License. using tensorflow::GraphDef; using tensorflow::NodeDef; +static void BoolDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } @@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +TF_Tensor* BoolTensor(bool v) { + const int num_bytes = sizeof(bool); + bool* values = new bool[1]; + values[0] = v; + return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator, + nullptr); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, return op; } +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name) { unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 38313d647ca93d4779bb1325f8ed7bde4b743879..7eeb1ee5e17ad7e5644f8bc8a18ca967b108475d 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -31,6 +31,8 @@ using ::tensorflow::string; typedef std::unique_ptr unique_tensor_ptr; +TF_Tensor* BoolTensor(int32_t v); + // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); @@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 74bc25a491ac01cb725d1c004197e48727c30230..d3311f0cd06f2b151c3567735eb41b5baf72e102 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() { const auto& slice_proto = entry.slices(i); CHECK(filtered_keys .insert(EncodeTensorNameSlice( - std::string(v2_reader_->key()) /* full var's name */, + string(v2_reader_->key()) /* full var's name */, TensorSlice(slice_proto))) .second); } @@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() { new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue; + if (filtered_keys.count(string(v2_reader_->key())) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - string key = std::string(v2_reader_->key()); + string key(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); (*var_to_data_type_map)[key] = DataType(entry.dtype()); } diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index 4de1300a7f66a8b4eb8074819432fd7dd597bb15..91654c8d4fb8067ae1fb525ebaa6c54689085545 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_CHECKPOINT_READER_H -#define TENSORFLOW_C_CHECKPOINT_READER_H +#ifndef TENSORFLOW_C_CHECKPOINT_READER_H_ +#define TENSORFLOW_C_CHECKPOINT_READER_H_ #include #include @@ -79,4 +79,4 @@ class CheckpointReader { } // namespace checkpoint } // namespace tensorflow -#endif // TENSORFLOW_C_CHECKPOINT_READER_H +#endif // TENSORFLOW_C_CHECKPOINT_READER_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 37be52f57d865c1e59611540d5dab04b59e89444..3ee31a6a7ac641bbd3fc4c05568b61e433a1d523 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -68,7 +68,10 @@ tf_cuda_library( tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], - visibility = ["//tensorflow:internal"], + visibility = [ + "//learning/deepmind/courier:__pkg__", + "//tensorflow:internal", + ], deps = [ ":c_api", "//tensorflow/c:c_api", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc old mode 100644 new mode 100755 index a0a44440c891c4b9bd6d43299e0ececa25a6b709..0bf3d9542b72ecff916986ab809e8793b796d14c --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices( tensorflow::Status CreateRemoteContexts( const std::vector& remote_workers, int64 rendezvous_id, - const tensorflow::ServerDef& server_def, + int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); + request.set_keep_alive_secs(keep_alive_secs); auto* eager_client = remote_eager_workers->GetClient(remote_worker); if (eager_client == nullptr) { return tensorflow::errors::Internal( @@ -151,7 +152,8 @@ tensorflow::Status CreateRemoteContexts( } tensorflow::Status UpdateTFE_ContextWithServerDef( - const tensorflow::ServerDef& server_def, TFE_Context* ctx) { + int keep_alive_secs, const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error @@ -202,8 +204,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // Initialize remote eager workers. tensorflow::gtl::FlatMap remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, rendezvous_id, server_def, remote_eager_workers.get(), - ctx->context.Async(), &remote_contexts)); + remote_workers, rendezvous_id, keep_alive_secs, server_def, + remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); @@ -222,9 +224,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( auto* device_mgr = grpc_server->worker_env()->device_mgr; - ctx->context.InitializeRemote( - std::move(server), std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts, r, device_mgr); + ctx->context.InitializeRemote(std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, + r, device_mgr, keep_alive_secs); return tensorflow::Status::OK(); #undef LOG_AND_RETURN_IF_ERROR @@ -241,8 +244,8 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, } void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, - unsigned char async) { - options->async = async; + unsigned char enable) { + options->async = enable; } void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { @@ -250,9 +253,9 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( } TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, - unsigned char async, + unsigned char enable, TF_Status* status) { - status->status = ctx->context.SetAsyncForThread(async); + status->status = ctx->context.SetAsyncForThread(enable); } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } @@ -270,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { new tensorflow::IntraProcessRendezvous(device_mgr.get()); return new TFE_Context(opts->session_options.options, opts->policy, - opts->async, std::move(device_mgr), r); + opts->async, device_mgr.release(), + /*device_mgr_owned*/ true, r); +} + +TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, + TF_Session* sess, TF_Status* status) { + const tensorflow::DeviceMgr* device_mgr = nullptr; + status->status = sess->session->LocalDeviceManager(&device_mgr); + if (!status->status.ok()) return nullptr; + tensorflow::Rendezvous* r = + new tensorflow::IntraProcessRendezvous(device_mgr); + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, /*device_mgr_owned*/ false, + r); } void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } @@ -288,6 +304,7 @@ void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } // Set server_def on the context, possibly updating it. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, const void* proto, size_t proto_len, TF_Status* status) { @@ -297,7 +314,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, "Invalid tensorflow.ServerDef protocol buffer"); return; } - status->status = UpdateTFE_ContextWithServerDef(server_def, ctx); + status->status = + UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); } void TFE_ContextSetThreadLocalDevicePlacementPolicy( @@ -357,6 +375,17 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { return result; } +int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } + tensorflow::int64 result; + status->status = h->handle->NumElements(&result); + return result; +} + int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { @@ -381,6 +410,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { : d->name().c_str(); } +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + + h->handle->Ref(); + + return new TFE_TensorHandle(h->handle); +} + TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { status->status = tensorflow::errors::InvalidArgument( @@ -536,6 +578,13 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } +void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, + TF_Status* status) { + tensorflow::Tensor t; + status->status = TF_TensorToTensor(tensor, &t); + if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t); +} + void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { @@ -719,6 +768,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace +void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } + +void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } + namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h old mode 100644 new mode 100755 index 25cf7adbc737411e93afe13a69850435994a1cd2..6323f8a053197bb7069acf2d43214fb78c36f436 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -76,7 +76,7 @@ typedef enum TFE_ContextDevicePlacementPolicy { // Sets the default execution mode (sync/async). Note that this can be // overridden per thread using TFE_ContextSetAsyncForThread. TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, - unsigned char async); + unsigned char enable); TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); @@ -114,7 +114,7 @@ TFE_ContextGetDevicePlacementPolicy(TFE_Context*); // Overrides the execution mode (sync/async) for the current thread. TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, - unsigned char async, + unsigned char enable, TF_Status* status); // A tensorflow.ServerDef specifies remote workers (in addition to the current @@ -124,6 +124,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, // If the following is set, all servers identified by the // ServerDef must be up when the context is created. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, const void* proto, size_t proto_len, TF_Status* status); @@ -162,6 +163,8 @@ TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status); +TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, + TF_Status* status); // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, @@ -170,6 +173,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor +// with `h`. On success, `status` is set to OK. On failure, `status` reflects +// the error and a nullptr is returned. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status); + // This function will block till the operation that produces `h` has // completed. The memory returned might alias the internal memory used by // TensorFlow. Hence, callers should not mutate this memory (for example by @@ -304,6 +313,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op, + const char* attr_name, + TF_Tensor* tensor, + TF_Status* status); + TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, @@ -380,6 +394,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Some TF ops need a step container to be set to limit the lifetime of some +// resources (mostly TensorArray and Stack, used in while loop gradients in +// graph mode). Calling this on a context tells it to start a step. +TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx); + +// Ends a step. When there is no active step (that is, every started step has +// been ended) step containers will be cleared. Note: it is not safe to call +// TFE_ContextEndStep while ops which rely on the step container may be running. +TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index a5c0681e2e4eddae08954d9d0178ca96a3f8f29a..104d52430cf7aa14d4d2a335a1b96e667f21ce87 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -62,15 +62,14 @@ struct TFE_ContextOptions { }; struct TFE_Context { - explicit TFE_Context(const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, - bool async, - std::unique_ptr device_mgr, - tensorflow::Rendezvous* rendezvous) + TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool async, + const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, + tensorflow::Rendezvous* rendezvous) : context(opts, static_cast( default_policy), - async, std::move(device_mgr), rendezvous) {} + async, device_mgr, device_mgr_owned, rendezvous) {} tensorflow::EagerContext context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 00a0a71fca5537bb65c76cb39c080c59160c5960..55331022b9dbd0696928fa44430f340f371432ac 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -151,7 +151,7 @@ void TestRemoteExecute(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); @@ -239,7 +239,7 @@ void TestRemoteExecuteSilentCopies(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); @@ -371,7 +371,7 @@ void TestRemoteExecuteChangeServerDef(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const char remote_device_name[] = @@ -397,7 +397,7 @@ void TestRemoteExecuteChangeServerDef(bool async) { ASSERT_TRUE(s.ok()) << s.error_message(); ASSERT_TRUE(worker_server->Start().ok()); - TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Create a new tensor_handle. @@ -1471,4 +1471,86 @@ void BM_ReadVariable(int iters) { } BENCHMARK(BM_ReadVariable); +TEST(CAPI, StringAttributes) { + // Test that TFE_OpSetAttrString doesn't hold on to the value after it + // returns. + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::vector dims(4, 1); + TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* tensor = + TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float)); + float tensor_data[] = {1}; + memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor)); + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, tensor_handle, status); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(tensor_handle); + + std::vector values(4, 1); + TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size()); + TFE_OpSetAttrIntList(op, "strides", values.data(), values.size()); + + const int BUFFER_SIZE = 10; + char buffer[BUFFER_SIZE]; + std::strncpy(buffer, "VALID", BUFFER_SIZE); + TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer)); + // Overwriting value in "buffer", should be fine since TFE_Op + // shouldn't be holding on to it. + std::strncpy(buffer, "NHWC", BUFFER_SIZE); + TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer)); + + TFE_OpSetAttrType(op, "T", TF_FLOAT); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(op, &retvals[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + tensor = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(4, TF_TensorByteSize(tensor)); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(op); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + +TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { + TFE_TensorHandle* h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); + + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + TFE_TensorHandle* h_shares_tensor = + TFE_TensorHandleCopySharingTensor(h, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get()); + ASSERT_EQ(16, TF_TensorByteSize(t)); + float data[4] = {0}; + memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(1.0, data[0]); + EXPECT_EQ(2.0, data[1]); + EXPECT_EQ(3.0, data[2]); + EXPECT_EQ(4.0, data[3]); + TF_DeleteTensor(t); + + TFE_DeleteTensorHandle(h); + TFE_DeleteTensorHandle(h_shares_tensor); +} } // namespace diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 1adb0458c35193117b5fa5cfe9ceffbaaf699af7..41b5b8ff36e16100e349cb909dc79d90fa4866b0 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -29,15 +29,8 @@ limitations under the License. namespace tensorflow { namespace eager { -// Information about a tensor. -struct TapeTensor { - int64 id; // Expected to be unique in the lifetime of this process. - DataType dtype; - TensorShape shape; -}; - // Represents an entry in the tape. -template +template struct OpTapeEntry { string op_type; std::vector output_tensor_info; @@ -57,8 +50,8 @@ struct OpTapeEntry { using TensorTape = gtl::FlatMap; // Map from operation-id to tape entry. -template -using OpTape = gtl::FlatMap>; +template +using OpTape = gtl::FlatMap>; // Operations the tape needs to perform on tensors to do backpropagation. Named // "vspace" because a subset of these are related to a vector space, such as @@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap>; // TODO(apassos) provide concrete template instantiations for TFE_TensorHandle // specialization, which is blocked by quite a few things needing to loop back // into python now. -template +template class VSpace { public: virtual ~VSpace() {} @@ -93,10 +86,10 @@ class VSpace { gtl::ArraySlice gradient_tensors) const = 0; // Returns a tensor of the right shape and dtype filled with zeros. - virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Zeros(const TapeTensor& tensor) const = 0; // Returns a Tensor which is filled with ones and like the input. - virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Ones(const TapeTensor& tensor) const = 0; // Calls the passed-in backward function. virtual Status CallBackwardFunction( @@ -114,7 +107,7 @@ class VSpace { // Traces the execution of operations, doing eager garbage collection, and // exporting a full trace so other code can do backpropagation. Not thread-safe. -template +template class GradientTape { public: // If `persistent` is true, GradientTape will not eagerly delete backward @@ -134,7 +127,7 @@ class GradientTape { void Watch(int64 tensor_id); void RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, + const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, BackwardFunction* backward_function, @@ -146,17 +139,18 @@ class GradientTape { // once) and produces the gradient of the target tensors with respect to the // source tensors. The output gradients are used if not empty and not // null. The result is populated with one tensor per target element. - Status ComputeGradient(const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_id, - gtl::ArraySlice output_gradients, - std::vector* result); + Status ComputeGradient( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_id, + gtl::ArraySlice output_gradients, + std::vector* result); bool IsPersistent() const { return persistent_; } private: TensorTape tensor_tape_; - OpTape op_tape_; + OpTape op_tape_; int64 next_op_id_{0}; // Map from tensor id to number of remaining usages (i.e. how many entries in @@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) { } } -template -bool GradientTape::ShouldRecord( +template +bool GradientTape::ShouldRecord( gtl::ArraySlice tensor_ids, gtl::ArraySlice dtypes) { CHECK_EQ(tensor_ids.size(), dtypes.size()); @@ -201,14 +195,15 @@ bool GradientTape::ShouldRecord( return false; } -template -void GradientTape::Watch(int64 tensor_id) { +template +void GradientTape::Watch( + int64 tensor_id) { tensor_tape_.emplace(tensor_id, -1); } -template -void GradientTape::RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, +template +void GradientTape::RecordOperation( + const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, BackwardFunction* backward_function, @@ -229,16 +224,18 @@ void GradientTape::RecordOperation( for (const TapeTensor& o : output_tensors) { // Note: the tensor can have already been watched and hence be in the tape, // so we cannot check that we're inserting it here. - tensor_tape_[o.id] = op_id; - tensor_usage_[o.id] = 1; + tensor_tape_[o.GetID()] = op_id; + tensor_usage_[o.GetID()] = 1; tensors.push_back(o); } - op_tape_[op_id] = OpTapeEntry{ - op_type, tensors, ids, backward_function, backward_function_deleter}; + op_tape_[op_id] = OpTapeEntry{ + op_type, std::move(tensors), ids, backward_function, + backward_function_deleter}; } -template -void GradientTape::DeleteTrace(int64 tensor_id) { +template +void GradientTape::DeleteTrace( + int64 tensor_id) { auto it = tensor_usage_.find(tensor_id); if (it == tensor_usage_.end()) { return; @@ -261,7 +258,7 @@ void GradientTape::DeleteTrace(int64 tensor_id) { auto op_it = op_tape_.find(op_id); CHECK(op_it != op_tape_.end()); for (const auto& output : op_it->second.output_tensor_info) { - if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) { // Found a usage for an output, so cannot delete the op. return; } @@ -304,9 +301,9 @@ void GradientTape::DeleteTrace(int64 tensor_id) { namespace { -template +template struct BackpropInitialState { - OpTape op_tape; + OpTape op_tape; // Map from tensor ID to how many references still exist for this tensor in // the tape. @@ -322,17 +319,17 @@ struct BackpropInitialState { // If `persistent_tape` is false, op_tape is cleared and backwards functions // not needed for gradient computation are deleted. Backwards functions that // are needed, are copied and returned in BackpropInitialState. -template -BackpropInitialState PrepareBackprop( +template +BackpropInitialState PrepareBackprop( gtl::ArraySlice target, const TensorTape& tensor_tape, - OpTape* op_tape, const gtl::FlatSet& sources_set, - bool persistent_tape) { + OpTape* op_tape, + const gtl::FlatSet& sources_set, bool persistent_tape) { std::vector tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { tensor_stack.push_back(t); } - BackpropInitialState result; + BackpropInitialState result; while (!tensor_stack.empty()) { int64 tensor_id = tensor_stack.back(); tensor_stack.pop_back(); @@ -383,9 +380,9 @@ BackpropInitialState PrepareBackprop( return result; } -template +template std::vector InitialStack( - const OpTape& op_tape, + const OpTape& op_tape, const gtl::FlatMap& op_missing_tensor) { std::vector result; for (auto& op_entry : op_tape) { @@ -396,13 +393,13 @@ std::vector InitialStack( return result; } -template -Status InitialGradients(const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice output_gradients, - const TensorTape& tensor_tape, - const OpTape& op_tape, - gtl::FlatMap>* result) { +template +Status InitialGradients( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, + const OpTape& op_tape, + gtl::FlatMap>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; if (output_gradients.empty() || output_gradients[i] == nullptr) { @@ -416,11 +413,10 @@ Status InitialGradients(const VSpace& vspace, } bool found = false; for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { - if (op_it->second.output_tensor_info[j].id == id) { + if (op_it->second.output_tensor_info[j].GetID() == id) { found = true; (*result)[id].push_back( - vspace.Ones(op_it->second.output_tensor_info[j].shape, - op_it->second.output_tensor_info[j].dtype)); + vspace.Ones(op_it->second.output_tensor_info[j])); break; } } @@ -440,6 +436,27 @@ Status InitialGradients(const VSpace& vspace, return Status::OK(); } +// TODO(agarwal): use an automatic mechanism for handling None arguments to +// gradient functions. +// +// Some gradient functions can accept None arguments for gradients. The +// following maps the operation name to the indices at which the corresponding +// gradient function can accept None values. e.g. FusedBatchNorm outputs 5 +// values and hence receives 5 gradient values during backprop. However the +// gradient function uses only the first of those values and ignores the rest. +// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient +// corresponding to index 0 is used, and the gradient values at indices 1-4 are +// ignored (and hence can be None). The backprop algorithm can then leverage +// this by not constructing zeros to pass for those indices. +gtl::FlatMap>* FunctionsAcceptingNoneForIndicesMap() { + static auto* const m = new gtl::FlatMap>({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"SparseSoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + return m; +} + } // namespace // If over kMinAggregateCount gradients are accumulated and the total @@ -448,16 +465,16 @@ Status InitialGradients(const VSpace& vspace, constexpr int kMinAggregateCount = 4; constexpr int kMinAggregateBytes = 128 * 1024 * 1024; -template -Status GradientTape::ComputeGradient( - const VSpace& vspace, +template +Status GradientTape::ComputeGradient( + const VSpace& vspace, gtl::ArraySlice target_tensor_ids, gtl::ArraySlice source_tensor_ids, gtl::ArraySlice output_gradients, std::vector* result) { gtl::FlatSet sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); - BackpropInitialState state = PrepareBackprop( + BackpropInitialState state = PrepareBackprop( target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); @@ -485,10 +502,6 @@ Status GradientTape::ComputeGradient( VLOG(1) << " " << t; } } - gtl::FlatMap> functions_accept_none_for_indices({ - {"SoftmaxCrossEntropyWithLogits", {1}}, - {"FusedBatchNorm", {1, 2, 3, 4}}, - }); while (!op_stack.empty()) { const int64 op = op_stack.back(); VLOG(1) << "Popped " << op; @@ -505,18 +518,16 @@ Status GradientTape::ComputeGradient( out_gradients.reserve(trace.output_tensor_info.size()); bool any_gradient_nonzero = false; for (int i = 0; i < trace.output_tensor_info.size(); ++i) { - const int64 id = trace.output_tensor_info[i].id; + const int64 id = trace.output_tensor_info[i].GetID(); auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { auto func_name_it = - functions_accept_none_for_indices.find(trace.op_type); - if (func_name_it != functions_accept_none_for_indices.end() && + FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); + if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() && func_name_it->second.find(i) != func_name_it->second.end()) { out_gradients.push_back(nullptr); } else { - out_gradients.push_back( - vspace.Zeros(trace.output_tensor_info[i].shape, - trace.output_tensor_info[i].dtype)); + out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i])); } } else { any_gradient_nonzero = true; diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 8486b585c8587e18e8eea18a893fac0a40ff4a27..247236b760dd8c07bbb08426100b6a4d34296d2e 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) { session->extend_before_run = false; } -std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { +std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { Node* node = &output.oper->node; CppShapeInferenceResult::HandleData handle_data; handle_data.set_is_set(true); @@ -135,9 +135,8 @@ std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { return result; } -void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, - const void* proto, size_t proto_len, - TF_Status* status) { +void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, + size_t proto_len, TF_Status* status) { tensorflow::CppShapeInferenceResult::HandleData handle_data; if (!handle_data.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 4bcb5bde62c8a4df4e68c1ce0daaf459434ceb5d..5cce84020bc68d912d259f51512341eb5f464a2c 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -54,16 +54,17 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); void ExtendSession(TF_Session* session, TF_Status* status); // Returns the serialized CppShapeInferenceResult::HandleData proto for -// `output` if its a resource tensor, or otherwise returns the empty string. -std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output); +// `output` if its a resource or variant tensor, or otherwise returns the empty +// string. +std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // Sets `output` based on `proto`, which should be a serialized -// CppShapeInferenceResult::HandleData proto. +// CppShapeInferenceResult::HandleData proto. `output` should be a resource +// or variant tensor. // NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string // because I couldn't get SWIG to work otherwise. -void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, - const void* proto, size_t proto_len, - TF_Status* status); +void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, + size_t proto_len, TF_Status* status); } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h index 86e687df205617018d94c19ac34fdc3bf54dcc6f..7661a01de4afcefbb66b33a05534e22d2ba1baa0 100644 --- a/tensorflow/c/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H -#define TENSORFLOW_C_TF_STATUS_HELPER_H +#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H_ +#define TENSORFLOW_C_TF_STATUS_HELPER_H_ #include "tensorflow/c/c_api.h" #include "tensorflow/core/lib/core/status.h" @@ -29,4 +29,4 @@ Status StatusFromTF_Status(const TF_Status* tf_status); } // namespace tensorflow -#endif // TENSORFLOW_C_TF_STATUS_HELPER_H +#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 588a45ea43f90c4d9b3d04fea305d2c562ae1d72..b587e63227708427e7fae47f8f4a7b524d963ed9 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -10,11 +10,12 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", - "tf_cc_test", + "cc_library_with_android_deps", "tf_cc_binary", + "tf_cc_test", "tf_copts", "tf_gen_op_wrappers_cc", - "cc_library_with_android_deps", + "transitive_hdrs", ) cc_library( @@ -379,9 +380,11 @@ tf_cc_test( srcs = ["gradients/math_grad_test.cc"], deps = [ ":cc_ops", + ":client_session", ":grad_op_registry", ":grad_testutil", ":gradient_checker", + ":gradients", ":math_grad", ":testutil", "//tensorflow/core:lib_internal", @@ -626,7 +629,6 @@ tf_cc_binary( copts = tf_copts(), linkopts = select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//tensorflow:darwin": [ "-lm", "-lpthread", @@ -715,3 +717,26 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) + +transitive_hdrs( + name = "headers", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":cc_ops", + ":client_session", + ":coordinator", + ":gradient_checker", + ":gradients", + ":ops", + ":queue_runner", + ":remote_fused_graph_ops", + ":scope", + "//tensorflow/cc/profiler", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc/saved_model:reader", + "//tensorflow/cc/saved_model:signature_constants", + "//tensorflow/cc/saved_model:tag_constants", + "//tensorflow/cc/tools:freeze_saved_model", + ], +) diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index dfdef88945deca376368edd6f7aa322b1e1cbf94..a32d1b1eb50fc715084f5ee663a732770db1883c 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -466,7 +466,7 @@ string AvoidCPPKeywords(StringPiece name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } - return std::string(name); + return string(name); } void InferArgAttributes(const OpDef::ArgDef& arg, @@ -508,15 +508,6 @@ bool HasOptionalAttrs( return false; } -const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { - for (int i = 0; i < api_def.in_arg_size(); ++i) { - if (api_def.in_arg(i).name() == name) { - return &api_def.in_arg(i); - } - } - return nullptr; -} - struct OpInfo { // graph_op_def: The OpDef used by the runtime, has the names that // must be used when calling NodeBuilder. diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index a085e1d6e2de5ad63d11eb8979ae64c26b91366f..0717e7dd4b358d6c212070374bcc3fd2f91ed0ab 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -150,7 +150,7 @@ class Input { Initializer(const std::initializer_list& v, const TensorShape& shape) { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); - if (t.NumElements() != v.size()) { + if (t.NumElements() != static_cast(v.size())) { status = errors::InvalidArgument( "Cannot construct a tensor with ", t.NumElements(), " from an initializer list with ", v.size(), " elements"); diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 8c886f31711eb014fb9e9d600c9c78cf22073f71..7f6ac4cae78d8d6e118837fce9ae5270336cdc89 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -225,7 +225,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( for (const string& entry : node_constraints) { StringPiece s(entry); if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { - current_constraints.insert(std::string(s)); + current_constraints.emplace(s); } } } else { diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index b353accddcb6db9a07c112de03ead2f02c4ee6a6..e9173227aadbf86eab666e6c17bacacb92888572 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Split", SplitGrad); +Status FillGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = fill(fill_shape, x) + // No gradient returned for the fill_shape argument. + grad_outputs->push_back(NoGradient()); + // The gradient for x (which must be a scalar) is just the sum of + // all the gradients from the shape it fills. + // We use ReduceSum to implement this, which needs an argument providing + // the indices of all the dimensions of the incoming gradient. + // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))]) + auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]), + Const(scope, 1)); + grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims)); + return scope.status(); +} +REGISTER_GRADIENT_OP("Fill", FillGrad); + Status DiagGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index d09275b6487b4212aa35a0476002f2bb587fa210..f41de3dc2098df55fbbb616557f264a4e70db6b6 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) { RunTest({x}, {x_shape}, y.output, {y_shape, y_shape}); } +TEST_F(ArrayGradTest, FillGrad) { + TensorShape x_shape({}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + TensorShape y_shape({2, 5, 3}); + auto y = Fill(scope_, {2, 5, 3}, x); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(ArrayGradTest, DiagGrad) { TensorShape x_shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 35a01e0341cb08c9b314908b6dcd76fd99c1e68b..1329b568ab8d4cc5cc5eed554e74bf1100d9bdcf 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -441,6 +441,21 @@ Status RealDivGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("RealDiv", RealDivGrad); +Status DivNoNanGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto x_1 = ConjugateHelper(scope, op.input(0)); + auto x_2 = ConjugateHelper(scope, op.input(1)); + // y = x_1 / x_2 + // dy/dx_1 = 1/x_2 + // dy/dx_2 = -x_1/x_2^2 + auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2); + auto gx_2 = Mul(scope, grad_inputs[0], + DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2)); + return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); +} +REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad); + Status SquaredDifferenceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -1007,6 +1022,26 @@ Status ProdGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Prod", ProdGrad); +Status SegmentSumGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // The SegmentSum operation sums segments of the Tensor that have the same + // index in the segment_ids parameter. + // i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1] + // will produce [2 + 3 + 4, 5] = [9, 5] + // The gradient that will flow back to the gather operation will look like + // [x1, x2], it will have the same shape as the output of the SegmentSum + // operation. The differentiation step of the SegmentSum operation just + // broadcast the gradient in order to retrieve the z's shape. + // dy/dz = [x1, x1, x1, x2] + grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1))); + + // stop propagation along segment_ids + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad); + // MatMulGrad helper function used to compute two MatMul operations // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1c9bdff5e1295135abe60c282d565c39071fd78a..c16938322c3555939ace1013f3bb95c5689b503e 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/gradient_checker.h" +#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -31,6 +33,7 @@ using ops::AddN; using ops::BatchMatMul; using ops::Const; using ops::Div; +using ops::DivNoNan; using ops::MatMul; using ops::Max; using ops::Maximum; @@ -42,6 +45,7 @@ using ops::Placeholder; using ops::Pow; using ops::Prod; using ops::RealDiv; +using ops::SegmentSum; using ops::SquaredDifference; using ops::Sub; using ops::Sum; @@ -850,6 +854,36 @@ TEST_F(NaryGradTest, RealDiv) { RunTest({x}, {x_shape}, {y}, {x_shape}); } +TEST_F(NaryGradTest, DivNoNan) { + { + TensorShape x_shape({3, 2, 5}); + const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large + // division errors in the numeric estimator used by the gradient checker. + const auto y = DivNoNan( + scope_, x, Add(scope_, Const(scope_, 1), Abs(scope_, x))); + RunTest({x}, {x_shape}, {y}, {x_shape}); + } + { + // Return 0 gradient (rather than NaN) for division by zero. + const auto x = Placeholder(scope_, DT_FLOAT); + const auto zero = Const(scope_, 0.0); + const auto y = DivNoNan(scope_, x, zero); + + std::vector grad_outputs; + TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs)); + ClientSession session(scope_); + std::vector grad_result; + TF_EXPECT_OK( + session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result)); + EXPECT_EQ(grad_result.size(), 1); + EXPECT_EQ(grad_result[0].NumElements(), 3); + EXPECT_EQ(grad_result[0].flat()(0), 0.0f); + EXPECT_EQ(grad_result[0].flat()(1), 0.0f); + EXPECT_EQ(grad_result[0].flat()(2), 0.0f); + } +} + TEST_F(NaryGradTest, SquaredDifference) { TensorShape x1_shape({3, 2, 5}); TensorShape x2_shape({2, 5}); @@ -898,5 +932,14 @@ TEST_F(NaryGradTest, Prod) { RunTest({x}, {x_shape}, {y}, {y_shape}); } +TEST_F(NaryGradTest, SegmentSum) { + TensorShape x_shape({3, 4}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = SegmentSum(scope_, x, {0, 0, 1}); + // the sum is always on the first dimension + TensorShape y_shape({2, 4}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 98be66a6add67a8053e286521e564286cdb8ef8d..c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -148,7 +148,7 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return RunOnce(run_options, inputs, {}, {main_op_name.ToString()}, + return RunOnce(run_options, inputs, {}, {string(main_op_name)}, nullptr /* outputs */, &run_metadata, session); } return Status::OK(); @@ -170,7 +170,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " - "were restored."; + "were restored. File does not exist: " + << variables_index_path; return Status::OK(); } const string variables_path = @@ -181,12 +182,12 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_path_tensor.scalar()() = variables_path; std::vector> inputs = { - {variable_filename_const_op_name.ToString(), variables_path_tensor}}; + {string(variable_filename_const_op_name), variables_path_tensor}}; AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()}, + return RunOnce(run_options, inputs, {}, {string(restore_op_name)}, nullptr /* outputs */, &run_metadata, session); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index d2f803bd18b38ad5c1a8b5afd70531db117826ea..6c29f09cde7ee17c11cb44ce48d8e9128daae4d0 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -32,7 +32,6 @@ cc_library( deps = [ ":embedded_protocol_buffers", "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -48,12 +47,16 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/cpu:buffer_info_util", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -70,6 +73,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", "@llvm//:support", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep ], @@ -98,6 +102,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -187,11 +192,13 @@ cc_library( srcs = ["embedded_protocol_buffers.cc"], hdrs = ["embedded_protocol_buffers.h"], deps = [ - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", "@llvm//:target", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 8dbe1e11b7c392cca29fc8792d3cf9f1bf44f1fb..b17bc658fa06b9feb7edb292bd89ef31e6309169 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -19,23 +19,27 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/types/span.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" -#include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace tfcompile { namespace { +using BufferInfo = cpu_function_runtime::BufferInfo; + bool IsAlpha(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); } @@ -85,27 +89,36 @@ Status XLATypeToCpp(xla::PrimitiveType type, string* str) { return Status::OK(); } -// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1 -// values. There are `n` entries in `sizes`. -size_t total_buffer_bytes(const intptr_t* sizes, size_t n) { - size_t total = 0; - for (size_t i = 0; i < n; ++i) { - if (sizes[i] != -1) { - total += sizes[i]; - } - } - return total; +// Returns the sum of the size of each buffer in `buffer_infos`. +size_t TotalBufferBytes(const std::vector& buffer_infos) { + return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0}, + [](size_t size, const BufferInfo& buffer_info) { + return size + buffer_info.size(); + }); } -// Fills in arg_sizes with the byte size of each positional arg. -Status ComputeArgSizes(const CompileResult& compile_result, - std::vector* arg_sizes) { - const xla::ProgramShape& ps = compile_result.program_shape; - for (int i = 0; i < ps.parameters_size(); ++i) { - arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( - ps.parameters(i), compile_result.pointer_size)); - } - return Status::OK(); +// Returns a vector of BufferInfo instances in `buffer_infos` that are entry +// parameter buffers. +std::vector ExtractEntryParamBufferInfos( + const std::vector& buffer_infos) { + std::vector result; + std::copy_if(buffer_infos.begin(), buffer_infos.end(), + std::back_inserter(result), [](const BufferInfo& buffer_info) { + return buffer_info.is_entry_parameter(); + }); + return result; +} + +// Returns a vector of BufferInfo instances in `buffer_infos` that are temp +// buffers. +std::vector ExtractTempBufferInfos( + const std::vector& buffer_infos) { + std::vector result; + std::copy_if(buffer_infos.begin(), buffer_infos.end(), + std::back_inserter(result), [](const BufferInfo& buffer_info) { + return buffer_info.is_temp_buffer(); + }); + return result; } // Add (from,to) rewrite pairs based on the given shape. These rewrite pairs @@ -122,14 +135,14 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, indices = "[0]"; } else { for (int dim = 0; dim < shape.dimensions_size(); ++dim) { - dim_vars.push_back(strings::StrCat("size_t dim", dim)); - dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]"); - indices += strings::StrCat("[dim", dim, "]"); + dim_vars.push_back(absl::StrCat("size_t dim", dim)); + dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); + indices += absl::StrCat("[dim", dim, "]"); } } - rewrites->push_back({"{{I}}", strings::StrCat(i)}); + rewrites->push_back({"{{I}}", absl::StrCat(i)}); rewrites->push_back({"{{TYPE}}", type}); - rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")}); + rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{INDICES}}", indices}); return Status::OK(); @@ -145,8 +158,9 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, // text-templating mechanism. string RewriteWithName(const string& name, string code, const std::vector>& rewrites) { - str_util::ReplaceAllPairs(&code, rewrites); - return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true); + absl::StrReplaceAll(rewrites, &code); + absl::StrReplaceAll({{"{{NAME}}", name}}, &code); + return code; } // Generate methods for args (inputs). @@ -180,7 +194,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, arg_data({{I}}))){{INDICES}}; } )"; - *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.feed(i).name().empty()) { *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites); } @@ -221,7 +235,7 @@ Status GenResultMethods(const tf2xla::Config& config, result_data({{I}}))){{INDICES}}; } )"; - *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.fetch(i).name().empty()) { *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites); } @@ -278,6 +292,25 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { return Status::OK(); } +// Returns a list of C++ expressions that, when executed, will construct the +// BufferInfo instances in `buffer_infos`. +std::vector BufferInfosToCppExpression( + const std::vector& buffer_infos) { + std::vector buffer_infos_as_strings; + std::transform(buffer_infos.begin(), buffer_infos.end(), + std::back_inserter(buffer_infos_as_strings), + [](const BufferInfo& buffer_info) { + std::pair encoded = buffer_info.Encode(); + string encoded_second_as_str = + encoded.second == ~0ULL + ? "~0ULL" + : absl::StrCat(encoded.second, "ULL"); + return absl::StrCat( + "::tensorflow::cpu_function_runtime::BufferInfo({", + encoded.first, "ULL, ", encoded_second_as_str, "})"); + }); + return buffer_infos_as_strings; +} } // namespace Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, @@ -286,40 +319,46 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ValidateConfig(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64 result_index = compile_result.aot->result_buffer_index(); - const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); - if (result_index < 0 || result_index >= temp_sizes.size()) { + const std::vector& buffer_infos = + compile_result.aot->buffer_infos(); + const std::vector arg_index_table = + ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); + std::vector buffer_infos_as_strings = + BufferInfosToCppExpression(buffer_infos); + if (result_index < 0 || result_index >= buffer_infos.size()) { return errors::InvalidArgument("result index: ", result_index, " is outside the range of temp sizes: [0,", - temp_sizes.size(), ")"); + buffer_infos.size(), ")"); } // Compute sizes and generate methods. - std::vector arg_sizes; - TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes)); + std::vector buffer_infos_for_args = + ExtractEntryParamBufferInfos(buffer_infos); + std::vector buffer_infos_for_temps = + ExtractTempBufferInfos(buffer_infos); const xla::ProgramShape& ps = compile_result.program_shape; string methods_arg, methods_result; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); - const std::vector iarg(arg_sizes.begin(), arg_sizes.end()); - const std::vector itemp(temp_sizes.begin(), temp_sizes.end()); - const size_t arg_bytes_aligned = - cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size()); - const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size()); - const size_t temp_bytes_aligned = - cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size()); - const size_t temp_bytes_total = - total_buffer_bytes(itemp.data(), itemp.size()); + const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( + buffer_infos_for_args.data(), buffer_infos_for_args.size(), + /*allocate_entry_params=*/true); + const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args); + const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( + buffer_infos_for_temps.data(), buffer_infos_for_temps.size(), + /*allocate_entry_params=*/true); + const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps); // Create rewrite strings for namespace start and end. string ns_start; for (const string& n : opts.namespaces) { - ns_start += strings::StrCat("namespace ", n, " {\n"); + ns_start += absl::StrCat("namespace ", n, " {\n"); } ns_start += "\n"; string ns_end("\n"); for (int i = opts.namespaces.size() - 1; i >= 0; --i) { const string& n = opts.namespaces[i]; - ns_end += strings::StrCat("} // end namespace ", n, "\n"); + ns_end += absl::StrCat("} // end namespace ", n, "\n"); } // Generate metadata. @@ -343,8 +382,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // calling HloProfilePrinter::profile_counters_size. const string assign_profile_counters_size = opts.gen_hlo_profile_printer_data - ? "data->profile_counters_size = " - "data->hlo_profile_printer_data->profile_counters_size();" + ? "data->set_profile_counters_size(" + "data->hlo_profile_printer_data()->profile_counters_size());" : ""; // Use a poor-man's text templating mechanism; first populate the full header @@ -414,9 +453,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static constexpr size_t kNumArgs = {{ARG_NUM}}; // Byte size of each argument buffer. There are kNumArgs entries. - static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}}; - return kArgSizes; + static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { + return BufferInfos()[ArgIndexToBufferIndex()[index]].size(); } // Returns static data used to create an XlaCompiledCpuFunction. @@ -424,17 +462,17 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->raw_function = {{ENTRY}}; - data->arg_sizes = ArgSizes(); - data->num_args = kNumArgs; - data->temp_sizes = TempSizes(); - data->num_temps = kNumTemps; - data->result_index = kResultIndex; - data->arg_names = StaticArgNames(); - data->result_names = StaticResultNames(); - data->program_shape = StaticProgramShape(); - data->hlo_profile_printer_data = StaticHloProfilePrinterData(); - {{ASSIGN_PROFILE_COUNTERS_SIZE}} + data->set_raw_function({{ENTRY}}); + data->set_buffer_infos(BufferInfos()); + data->set_num_buffers(kNumBuffers); + data->set_arg_index_table(ArgIndexToBufferIndex()); + data->set_num_args(kNumArgs); + data->set_result_index(kResultIndex); + data->set_arg_names(StaticArgNames()); + data->set_result_names(StaticResultNames()); + data->set_program_shape(StaticProgramShape()); + data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); +{{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); return *kStaticData; @@ -482,17 +520,27 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {{METHODS_RESULT}} private: - // Number of result and temporary buffers for the compiled computation. - static constexpr size_t kNumTemps = {{TEMP_NUM}}; - // The 0-based index of the result tuple in the temporary buffers. - static constexpr size_t kResultIndex = {{RESULT_INDEX}}; + // Number of buffers for the compiled computation. + static constexpr size_t kNumBuffers = {{NUM_BUFFERS}}; + + static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() { + static const ::tensorflow::cpu_function_runtime::BufferInfo + kBufferInfos[kNumBuffers] = { +{{BUFFER_INFOS_AS_STRING}} + }; + return kBufferInfos; + } - // Byte size of each result / temporary buffer. There are kNumTemps entries. - static const intptr_t* TempSizes() { - static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}}; - return kTempSizes; + static const ::tensorflow::int32* ArgIndexToBufferIndex() { + static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { +{{ARG_INDEX_TABLE}} + }; + return kArgIndexToBufferIndex; } + // The 0-based index of the result tuple in the temporary buffers. + static constexpr size_t kResultIndex = {{RESULT_INDEX}}; + // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() {{ARG_NAMES_CODE}} @@ -520,15 +568,15 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { )"; // The replacement strategy is naive, but good enough for our purposes. const std::vector> rewrites = { - {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, - {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, + {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)}, + {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, - {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, - {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, + {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())}, + {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, {"{{DECLS_FROM_OBJ_FILE}}", - str_util::Join(metadata_result.header_variable_decls, "\n")}, + absl::StrJoin(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", metadata_result.hlo_profile_printer_data_access_shim}, @@ -542,24 +590,25 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, - {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, + {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, - {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, - {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, - {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, - {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}}; - str_util::ReplaceAllPairs(header, rewrites); + {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)}, + {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)}, + {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())}, + {"{{BUFFER_INFOS_AS_STRING}}", + absl::StrJoin(buffer_infos_as_strings, ",\n")}}; + absl::StrReplaceAll(rewrites, header); return Status::OK(); } static string CreateUniqueIdentifier(const CodegenOpts& opts, - StringPiece suffix) { + absl::string_view suffix) { string result = "__tfcompile"; for (const string& n : opts.namespaces) { - strings::StrAppend(&result, "_", n); + absl::StrAppend(&result, "_", n); } - strings::StrAppend(&result, "_", opts.class_name, "_", suffix); + absl::StrAppend(&result, "_", opts.class_name, "_", suffix); return result; } @@ -570,7 +619,8 @@ Status GenerateMetadata(const CodegenOpts& opts, if (opts.gen_program_shape) { program_shape = - tensorflow::MakeUnique(compile_result.program_shape); + absl::make_unique(compile_result.program_shape); + // The parameter names are currently meaningless, and redundant with the // rest of our metadata, so clear them out to avoid confusion and save // space. @@ -628,7 +678,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name, return Status::OK(); } -Status ValidateCppIdent(StringPiece ident, StringPiece msg) { +Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) { if (ident.empty()) { return errors::InvalidArgument("empty identifier: ", msg); } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 83f2d3ee11d09d66f16d7ecdc11945ebe994a82a..90410c46a8e36e44454f1219ad76d0fb0937070d 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { namespace tfcompile { @@ -96,7 +96,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name, // ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is // appended to error messages. -Status ValidateCppIdent(StringPiece ident, StringPiece msg); +Status ValidateCppIdent(absl::string_view ident, absl::string_view msg); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 29bc9c13b889c86c2ba8776c7b067c54cb05bc43..bb288d23000527be74f01630d20bbf82e50007ce 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -32,9 +32,11 @@ namespace tensorflow { namespace tfcompile { namespace { -void ExpectErrorContains(const Status& status, StringPiece str) { +using ::tensorflow::cpu_function_runtime::BufferInfo; + +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } @@ -171,8 +173,14 @@ TEST(CodegenTest, Golden) { fetch->mutable_id()->set_node_name("fetch0"); fetch->set_name("myfetch"); CompileResult compile_result; - compile_result.aot.reset( - new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {})); + compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( + {}, + {BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0), + BufferInfo::MakeTempBuffer(2), + BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), + BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, + 5, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 6641d45e83020f4144616a6a2837c844330298f5..e4d8a02877c75fa72c5747650ab9c7ac229955b3 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -65,9 +65,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static constexpr size_t kNumArgs = 2; // Byte size of each argument buffer. There are kNumArgs entries. - static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96}; - return kArgSizes; + static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { + return BufferInfos()[ArgIndexToBufferIndex()[index]].size(); } // Returns static data used to create an XlaCompiledCpuFunction. @@ -75,17 +74,17 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->raw_function = entry_point; - data->arg_sizes = ArgSizes(); - data->num_args = kNumArgs; - data->temp_sizes = TempSizes(); - data->num_temps = kNumTemps; - data->result_index = kResultIndex; - data->arg_names = StaticArgNames(); - data->result_names = StaticResultNames(); - data->program_shape = StaticProgramShape(); - data->hlo_profile_printer_data = StaticHloProfilePrinterData(); - + data->set_raw_function(entry_point); + data->set_buffer_infos(BufferInfos()); + data->set_num_buffers(kNumBuffers); + data->set_arg_index_table(ArgIndexToBufferIndex()); + data->set_num_args(kNumArgs); + data->set_result_index(kResultIndex); + data->set_arg_names(StaticArgNames()); + data->set_result_names(StaticResultNames()); + data->set_program_shape(StaticProgramShape()); + data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); + return data; }(); return *kStaticData; @@ -215,17 +214,32 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { } private: - // Number of result and temporary buffers for the compiled computation. - static constexpr size_t kNumTemps = 6; - // The 0-based index of the result tuple in the temporary buffers. - static constexpr size_t kResultIndex = 5; + // Number of buffers for the compiled computation. + static constexpr size_t kNumBuffers = 6; + + static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() { + static const ::tensorflow::cpu_function_runtime::BufferInfo + kBufferInfos[kNumBuffers] = { +::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL}) + }; + return kBufferInfos; + } - // Byte size of each result / temporary buffer. There are kNumTemps entries. - static const intptr_t* TempSizes() { - static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120}; - return kTempSizes; + static const ::tensorflow::int32* ArgIndexToBufferIndex() { + static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { +1, 3 + }; + return kArgIndexToBufferIndex; } + // The 0-based index of the result tuple in the temporary buffers. + static constexpr size_t kResultIndex = 5; + // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() { static const char* kNames[] = {"myfeed", nullptr}; diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 4e27aafec7747655d8e4ea3ddd1788d495ca0710..3c32d533f63f202fc9409f36709e0d29d1d7e002 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_replace.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" @@ -26,8 +28,6 @@ limitations under the License. #include "llvm/Support/TargetRegistry.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/tf2xla/str_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/util.h" @@ -38,11 +38,11 @@ using xla::llvm_ir::AsStringRef; static void AddEmbeddedProtocolBufferToLlvmModule( llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto, - StringPiece unique_identifier, string* protobuf_array_symbol_name, + absl::string_view unique_identifier, string* protobuf_array_symbol_name, int64* protobuf_array_size) { string protobuf_array_contents = proto.SerializeAsString(); *protobuf_array_symbol_name = - strings::StrCat(unique_identifier, "_protobuf_array_contents"); + absl::StrCat(unique_identifier, "_protobuf_array_contents"); *protobuf_array_size = protobuf_array_contents.size(); llvm::Constant* protobuf_array_initializer = @@ -55,9 +55,9 @@ static void AddEmbeddedProtocolBufferToLlvmModule( protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); } -static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, - StringPiece protobuf_array_symbol_name, - int64 protobuf_array_size) { +static string CreateCPPShimExpression( + absl::string_view qualified_cpp_protobuf_name, + absl::string_view protobuf_array_symbol_name, int64 protobuf_array_size) { string code = "[]() {\n" " {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n" @@ -65,14 +65,13 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, " return proto;\n" " }()"; - str_util::ReplaceAllPairs( - &code, + return absl::StrReplaceAll( + code, { - {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)}, - {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)}, - {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)}, + {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)}, + {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)}, + {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)}, }); - return code; } static StatusOr CodegenModule(llvm::TargetMachine* target_machine, @@ -94,10 +93,10 @@ static StatusOr CodegenModule(llvm::TargetMachine* target_machine, } static StatusOr> -GetTargetMachineFromTriple(StringPiece target_triple) { +GetTargetMachineFromTriple(absl::string_view target_triple) { std::string error; std::string normalized_triple = - llvm::Triple::normalize(AsStringRef(target_triple)); + llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); const llvm::Target* target = llvm::TargetRegistry::lookupTarget(normalized_triple, error); if (target == nullptr) { @@ -105,20 +104,20 @@ GetTargetMachineFromTriple(StringPiece target_triple) { error.c_str()); } - return WrapUnique(target->createTargetMachine( + return absl::WrapUnique(target->createTargetMachine( normalized_triple, /*CPU=*/"", /*Features=*/"", llvm::TargetOptions(), llvm::None)); } StatusOr CreateEmbeddedProtocolBuffers( - StringPiece target_triple, - gtl::ArraySlice protobufs_to_embed) { + absl::string_view target_triple, + absl::Span protobufs_to_embed) { TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, GetTargetMachineFromTriple(target_triple)); llvm::LLVMContext llvm_context; std::unique_ptr module_with_serialized_proto = - MakeUnique("embedded_data_module", llvm_context); + absl::make_unique("embedded_data_module", llvm_context); EmbeddedProtocolBuffers result; @@ -136,8 +135,8 @@ StatusOr CreateEmbeddedProtocolBuffers( protobuf_to_embed.qualified_cpp_protobuf_name, protobuf_array_symbol_name, protobuf_array_size); - cpp_variable_decl = strings::StrCat("extern \"C\" char ", - protobuf_array_symbol_name, "[];"); + cpp_variable_decl = + absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];"); } else { cpp_shim = "nullptr"; } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 4e194a6aba9a9efcad27c47c42e148d8e537ae68..cf5c04ac4bdff73b76a365c346f7db60ce2d8197 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -20,8 +20,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -83,8 +83,8 @@ struct ProtobufToEmbed { // is stored in the object_file_data field in the returned // EmbeddedProtocolBuffers instance. StatusOr CreateEmbeddedProtocolBuffers( - StringPiece target_triple, - gtl::ArraySlice protobufs_to_embed); + absl::string_view target_triple, + absl::Span protobufs_to_embed); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc index 6b098049cbd7539a2b2e2696b13139a8a6b28e0f..5deb47d12310d24dce847227bd119249210ffb8d 100644 --- a/tensorflow/compiler/aot/test.cc +++ b/tensorflow/compiler/aot/test.cc @@ -51,11 +51,9 @@ namespace tensorflow { namespace tfcompile { namespace { -void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) { - for (int i = 0; i < n; ++i) { - if (sizes[i] != -1) { - memset(bufs[i], 0, sizes[i]); - } +void zero_buffers(XlaCompiledCpuFunction* computation) { + for (int i = 0; i < computation->num_args(); ++i) { + memset(computation->arg_data(i), 0, computation->arg_size(i)); } } @@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) { CPP_CLASS computation; computation.set_thread_pool(&device); - zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + zero_buffers(&computation); EXPECT_TRUE(computation.Run()); } @@ -80,7 +78,7 @@ void BM_NAME(int iters) { CPP_CLASS computation; computation.set_thread_pool(&device); - zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + zero_buffers(&computation); testing::StartTiming(); while (--iters) { diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 0ecc3feeb6fef1dd691ab2785b3221075a79ba88..10fa33ab5e84dcbc1629bee6214e8969046f19c2 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -25,6 +25,7 @@ test_suite( ":test_graph_tfmatmul_test", ":test_graph_tfmatmulandadd_test", ":test_graph_tfsplits_test", + ":test_graph_tftop_k_test", ":tfcompile_test", ], ) @@ -42,6 +43,7 @@ py_binary( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", "//tensorflow/python:platform", "//tensorflow/python:session", "//tensorflow/python:training", @@ -66,8 +68,14 @@ genrule( "test_graph_tfmatmul.pb", "test_graph_tfmatmulandadd.pb", "test_graph_tfsplits.pb", + "test_graph_tftop_k.pb", ], - cmd = "$(location :make_test_graphs) --out_dir $(@D)", + # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any + # GPUs which might be present. This is important because builds may run + # concurrently with tests, and tests need to be able to assume that they + # have control of the full GPU. + cmd = "CUDA_VISIBLE_DEVICES='' " + + "$(location :make_test_graphs) --out_dir $(@D)", tags = ["manual"], tools = [":make_test_graphs"], ) @@ -187,6 +195,9 @@ tf_library( cpp_class = "MatMulAndAddCompWithProfiling", enable_xla_hlo_profiling = True, graph = "test_graph_tfmatmulandadd.pb", + tags = [ + "manual", + ], ) tf_library( @@ -200,6 +211,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tftop_k", + testonly = 1, + config = "test_graph_tftop_k.config.pbtxt", + cpp_class = "TopKComp", + graph = "test_graph_tftop_k.pb", + tags = [ + "manual", + ], +) + tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], @@ -218,13 +240,16 @@ tf_cc_test( ":test_graph_tfmatmulandadd", ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", + ":test_graph_tftop_k", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 9ec7df163b1425f917e9ec51559efad3e6f05e75..de135d7a2346cdda13a8e35315929b17fa1ccbc1 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables from tensorflow.python.platform import app from tensorflow.python.training import saver as saver_lib @@ -142,6 +143,12 @@ def tfsplits(_): array_ops.identity(y, name='result') +def tftop_k(_): + x = array_ops.placeholder(dtypes.int32, shape=[5], name='x') + output = nn_ops.top_k(x, 2, name='values') + array_ops.identity(output[1], name='indices') + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -163,6 +170,7 @@ def main(_): write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) + write_graph(tftop_k, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6b4ac2d7cbb517be841932b1cfae9e28decdf8d3 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt @@ -0,0 +1,13 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "x" } + shape { + dim { size: 5 } + } +} +fetch { + id { node_name: "values" } +} +fetch { + id { node_name: "indices" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index fee46280e9a0e7ba2cf7c3ed46469ae8cc0841d4..f10852c7850f61bfd8b99fa9f1648202d182085e 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #define EIGEN_USE_CUSTOM_THREAD_POOL +#include "absl/strings/str_split.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" @@ -28,11 +29,12 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -44,8 +46,8 @@ using ::testing::IsSupersetOf; TEST(TFCompileTest, Add) { AddComp add; - EXPECT_EQ(add.arg0_data(), add.args()[0]); - EXPECT_EQ(add.arg1_data(), add.args()[1]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); + EXPECT_EQ(add.arg1_data(), add.arg_data(1)); add.arg0() = 1; add.arg1() = 2; @@ -67,10 +69,10 @@ TEST(TFCompileTest, Add) { EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.arg0(), 123); EXPECT_EQ(add_const.arg0_data()[0], 123); - EXPECT_EQ(add_const.arg0_data(), add.args()[0]); + EXPECT_EQ(add_const.arg0_data(), add.arg_data(0)); EXPECT_EQ(add_const.arg1(), 456); EXPECT_EQ(add_const.arg1_data()[0], 456); - EXPECT_EQ(add_const.arg1_data(), add.args()[1]); + EXPECT_EQ(add_const.arg1_data(), add.arg_data(1)); EXPECT_EQ(add_const.result0(), 579); EXPECT_EQ(add_const.result0_data()[0], 579); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); @@ -85,8 +87,8 @@ TEST(TFCompileTest, Add_SetArg) { int32 arg_y = 32; add.set_arg0_data(&arg_x); add.set_arg1_data(&arg_y); - EXPECT_EQ(add.arg0_data(), add.args()[0]); - EXPECT_EQ(add.arg1_data(), add.args()[1]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); + EXPECT_EQ(add.arg1_data(), add.arg_data(1)); EXPECT_TRUE(add.Run()); EXPECT_EQ(add.error_msg(), ""); @@ -97,7 +99,7 @@ TEST(TFCompileTest, Add_SetArg) { TEST(TFCompileTest, AddWithCkpt) { AddWithCkptComp add; - EXPECT_EQ(add.arg0_data(), add.args()[0]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); add.arg0() = 1; EXPECT_TRUE(add.Run()); @@ -117,7 +119,7 @@ TEST(TFCompileTest, AddWithCkpt) { EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.arg0(), 111); EXPECT_EQ(add_const.arg0_data()[0], 111); - EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); + EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0)); EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0_data()[0], 153); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); @@ -125,7 +127,7 @@ TEST(TFCompileTest, AddWithCkpt) { TEST(TFCompileTest, AddWithCkptSaver) { AddWithCkptSaverComp add; - EXPECT_EQ(add.arg0_data(), add.args()[0]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); add.arg0() = 1; EXPECT_TRUE(add.Run()); @@ -145,7 +147,7 @@ TEST(TFCompileTest, AddWithCkptSaver) { EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.arg0(), 111); EXPECT_EQ(add_const.arg0_data()[0], 111); - EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); + EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0)); EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0_data()[0], 153); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); @@ -153,9 +155,9 @@ TEST(TFCompileTest, AddWithCkptSaver) { TEST(TFCompileTest, Cond) { CondComp cond; - EXPECT_EQ(cond.arg0_data(), cond.args()[0]); - EXPECT_EQ(cond.arg1_data(), cond.args()[1]); - EXPECT_EQ(cond.arg2_data(), cond.args()[2]); + EXPECT_EQ(cond.arg0_data(), cond.arg_data(0)); + EXPECT_EQ(cond.arg1_data(), cond.arg_data(1)); + EXPECT_EQ(cond.arg2_data(), cond.arg_data(2)); cond.arg1() = 10; cond.arg2() = 20; { @@ -178,8 +180,8 @@ TEST(TFCompileTest, Cond) { TEST(TFCompileTest, Gather) { GatherComp gather; - EXPECT_EQ(gather.arg0_data(), gather.args()[0]); - EXPECT_EQ(gather.arg1_data(), gather.args()[1]); + EXPECT_EQ(gather.arg0_data(), gather.arg_data(0)); + EXPECT_EQ(gather.arg1_data(), gather.arg_data(1)); // Successful gather. { @@ -202,12 +204,12 @@ TEST(TFCompileTest, Gather) { EXPECT_EQ(gather_const.arg0(i), params[i]); EXPECT_EQ(gather_const.arg0_data()[i], params[i]); } - EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]); + EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0)); for (int i = 0; i < 2; ++i) { EXPECT_EQ(gather_const.arg1(i), indices[i]); EXPECT_EQ(gather_const.arg1_data()[i], indices[i]); } - EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]); + EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1)); for (int i = 0; i < 2; ++i) { EXPECT_EQ(gather_const.result0(i), results[i]); EXPECT_EQ(gather_const.result0_data()[i], results[i]); @@ -222,8 +224,8 @@ TEST(TFCompileTest, MatMul2) { foo::bar::MatMulComp matmul; matmul.set_thread_pool(&device); - EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); - EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); + EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0)); + EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1)); // Test using the argN() methods. { @@ -271,12 +273,12 @@ TEST(TFCompileTest, MatMul2) { EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]); EXPECT_EQ(matmul_const.arg0_data()[i], args[i]); } - EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]); + EXPECT_EQ(matmul_const.arg0_data(), matmul.arg_data(0)); for (int i = 0; i < 6; ++i) { EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]); EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]); } - EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]); + EXPECT_EQ(matmul_const.arg1_data(), matmul.arg_data(1)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]); EXPECT_EQ(matmul_const.result0_data()[i], results[i]); @@ -300,8 +302,8 @@ TEST(TFCompileTest, MatMul2_SetArg) { float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}}; matmul.set_arg0_data(&arg0); matmul.set_arg1_data(&arg1); - EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); - EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); + EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0)); + EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1)); EXPECT_TRUE(matmul.Run()); EXPECT_EQ(matmul.error_msg(), ""); @@ -319,8 +321,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { MatMulAndAddComp muladd; muladd.set_thread_pool(&device); - EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]); - EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]); + EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0)); + EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1)); // Test methods with positional args and results. { @@ -346,12 +348,12 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]); EXPECT_EQ(muladd_const.arg0_data()[i], args[i]); } - EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]); + EXPECT_EQ(muladd_const.arg0_data(), muladd.arg_data(0)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]); EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]); } - EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]); + EXPECT_EQ(muladd_const.arg1_data(), muladd.arg_data(1)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]); EXPECT_EQ(muladd_const.result0_data()[i], results0[i]); @@ -387,12 +389,12 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]); EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]); } - EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]); + EXPECT_EQ(muladd_const.arg_x_data(), muladd.arg_data(0)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]); EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]); } - EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]); + EXPECT_EQ(muladd_const.arg_y_data(), muladd.arg_data(1)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]); EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]); @@ -407,8 +409,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { TEST(TFCompileTest, Function) { // The function is equivalent to an addition FunctionComp add_fn; - EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]); - EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]); + EXPECT_EQ(add_fn.arg0_data(), add_fn.arg_data(0)); + EXPECT_EQ(add_fn.arg1_data(), add_fn.arg_data(1)); add_fn.arg0() = 1; add_fn.arg1() = 2; @@ -447,12 +449,36 @@ TEST(TFCompileTest, Splits) { EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); } +TEST(TFCompileTest, TopK) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + TopKComp fn; + + fn.set_thread_pool(&device); + // x = [4, 1, 4, 4, 3] + fn.arg0(0) = 4; + fn.arg0(1) = 1; + fn.arg0(2) = 4; + fn.arg0(3) = 4; + fn.arg0(4) = 3; + + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); + const int32 expected_values[] = {4, 4}; + const int32 expected_indices[] = {0, 2}; + EXPECT_EQ(expected_values[0], fn.result0(0)); + EXPECT_EQ(expected_values[1], fn.result0(1)); + EXPECT_EQ(expected_indices[0], fn.result1(0)); + EXPECT_EQ(expected_indices[1], fn.result1(1)); +} + TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the // two args are different. AssertComp assert; - EXPECT_EQ(assert.arg0_data(), assert.args()[0]); - EXPECT_EQ(assert.arg1_data(), assert.args()[1]); + EXPECT_EQ(assert.arg0_data(), assert.arg_data(0)); + EXPECT_EQ(assert.arg1_data(), assert.arg_data(1)); assert.arg0() = 2; assert.arg1() = 1; @@ -543,24 +569,28 @@ TEST(TFCompileTest, HloProfiling) { string hlo_profile_as_string = xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(), /*clock_rate_ghz=*/1.0); - VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; + VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string; + + // Strip away identifier details from the profile string to avoid this test + // being a change detector for xla internals. Identifiers such as '%dot.0.7' + // just become '%dot'. + RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1"); + VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string; std::vector hlo_profile_lines = - tensorflow::str_util::Split(hlo_profile_as_string, '\n'); + absl::StrSplit(hlo_profile_as_string, '\n'); auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto add_profile_line = HasSubstr( - "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto tuple_profile_line = HasSubstr( - "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " - "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); - auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); - auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); + "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, " + "f32[2,2]{1,0} %add)"); + auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)"); EXPECT_THAT(hlo_profile_lines, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 326f73b975aec3a7a6bc7cdc9a92f540ad545ad6..859c84bb91657422b830255b0217f8946d351458 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -105,12 +105,18 @@ def tf_library( freeze_file = freeze_name + ".pb" # First run tfcompile to generate the list of out_nodes. + # + # Here and below, we set CUDA_VISIBLE_DEVICES='' to prevent the code we + # launch from using any GPUs which might be present. This is important + # because builds may run concurrently with tests, and tests need to be + # able to assume that they have control of the full GPU. out_nodes_file = "out_nodes_" + freeze_name native.genrule( name = ("gen_" + out_nodes_file), srcs = [config], outs = [out_nodes_file], - cmd = ("$(location " + tfcompile_tool + ")" + + cmd = ("CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + " --config=$(location " + config + ")" + " --dump_fetch_nodes > $@"), tools = [tfcompile_tool], @@ -142,9 +148,12 @@ def tf_library( out_nodes_file, ] + freeze_saver_srcs, outs = [freeze_file], - cmd = ("$(location " + - "//tensorflow/python/tools:freeze_graph)" + - freeze_args), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + + "//tensorflow/python/tools:freeze_graph)" + + freeze_args + ), tools = ["//tensorflow/python/tools:freeze_graph"], tags = tags, ) @@ -177,16 +186,19 @@ def tf_library( metadata_object_file, function_object_file, ], - cmd = ("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_header=$(@D)/" + header_file + - " --out_metadata_object=$(@D)/" + metadata_object_file + - " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_header=$(@D)/" + header_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + + " " + flags + " " + profiling_flag + ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, @@ -216,14 +228,17 @@ def tf_library( outs = [ session_module_pb, ], - cmd = ("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_session_module=$(@D)/" + session_module_pb + - " " + flags), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + flags + ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, @@ -258,6 +273,7 @@ def tf_library( "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort", "//tensorflow/compiler/xla/service/cpu:runtime_matmul", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 839e1588b7be6c91cf30c87bbaf75402446bd169..b95b063348c5cdfdcaed635ba527e9f0bfd6092d 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -18,6 +18,9 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" @@ -32,9 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -55,7 +56,7 @@ const char kUsageHeader[] = "\n"; Status ReadProtoFile(const string& fname, protobuf::Message* proto) { - if (str_util::EndsWith(fname, ".pbtxt")) { + if (absl::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { return ReadBinaryProto(Env::Default(), fname, proto); @@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) { for (const tf2xla::Fetch& fetch : config.fetch()) { nodes.insert(fetch.id().node_name()); } - std::cout << str_util::Join(nodes, ","); + std::cout << absl::StrJoin(nodes, ","); return Status::OK(); } @@ -91,8 +92,9 @@ Status Main(const MainFlags& flags) { // Write output files. Env* env = Env::Default(); const std::vector& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, - StringPiece(obj.data(), obj.size()))); + TF_RETURN_IF_ERROR( + WriteStringToFile(env, flags.out_function_object, + absl::string_view(obj.data(), obj.size()))); CodegenOpts codegen_opts; codegen_opts.gen_name_to_index = flags.gen_name_to_index; codegen_opts.gen_program_shape = flags.gen_program_shape; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index d3238c6a5efbf01a1e1b9e7a1bb8130055464b4d..4e184729efeb4a7f11810afe2f1c48bb75c33e4a 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -26,6 +26,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( @@ -50,7 +51,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":jit_compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", ], @@ -62,7 +63,7 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda([ ":jit_compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", ]), @@ -76,7 +77,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -94,7 +95,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep @@ -111,7 +112,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep @@ -128,11 +129,11 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -160,6 +161,7 @@ cc_library( "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:util", @@ -178,6 +180,7 @@ cc_library( "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:fifo_queue", + "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", @@ -186,6 +189,10 @@ cc_library( "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/kernels/data:generator_dataset_op", + "//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:prefetch_dataset_op", + "@com_google_absl//absl/memory", ], ) @@ -230,6 +237,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/memory", ], ) @@ -258,6 +266,7 @@ cc_library( srcs = ["jit_compilation_pass_registration.cc"], deps = [ ":compilation_passes", + "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu_internal", ], alwayslink = 1, @@ -272,12 +281,13 @@ cc_library( deps = [ ":common", ":compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -298,23 +308,75 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "resource_operation_safety_analysis", + srcs = ["resource_operation_safety_analysis.cc"], + hdrs = ["resource_operation_safety_analysis.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "resource_operation_safety_analysis_test", + srcs = ["resource_operation_safety_analysis_test.cc"], + deps = [ + ":common", + ":resource_operation_safety_analysis", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", ], ) cc_library( name = "compilation_passes", srcs = [ - "build_xla_launch_ops_pass.cc", + "build_xla_ops_pass.cc", "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", + "encapsulate_xla_computations_pass.cc", "mark_for_compilation_pass.cc", + "mark_for_compilation_pass_test_helper.cc", + "partially_decluster_pass.cc", ], hdrs = [ - "build_xla_launch_ops_pass.h", + "build_xla_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", + "encapsulate_xla_computations_pass.h", "mark_for_compilation_pass.h", + "mark_for_compilation_pass_test_helper.h", + "partially_decluster_pass.h", ], deps = [ ":common", @@ -322,11 +384,10 @@ cc_library( ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", - "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", - "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -338,6 +399,9 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -346,11 +410,14 @@ cc_library( srcs = ["xla_cluster_util.cc"], hdrs = ["xla_cluster_util.h"], deps = [ + ":resource_operation_safety_analysis", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -393,7 +460,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -412,18 +479,25 @@ tf_cc_test( size = "small", srcs = [ "encapsulate_subgraphs_pass_test.cc", + "encapsulate_xla_computations_pass_test.cc", "mark_for_compilation_pass_test.cc", + "partially_decluster_pass_test.cc", ], deps = [ ":common", ":compilation_passes", + ":xla_cluster_util", + ":xla_gpu_device", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -432,6 +506,9 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/grappler/optimizers/data:graph_utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -448,7 +525,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -502,6 +579,7 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "@com_google_absl//absl/strings", ], ) @@ -512,6 +590,9 @@ tf_cuda_cc_test( ":common", ":xla_cluster_util", ":xla_fusion_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/core:graph", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -519,6 +600,44 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "node_matchers", + testonly = True, + srcs = ["node_matchers.cc"], + hdrs = ["node_matchers.h"], + deps = [ + "//tensorflow/cc:ops", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "node_matchers_test", + srcs = ["node_matchers_test.cc"], + deps = [ + ":node_matchers", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/core:ops", + "//tensorflow/core:test_main", + ], +) + +tf_custom_op_py_library( + name = "xla_ops_py", + kernels = ["//tensorflow/compiler/jit/ops:xla_ops"], + visibility = [ + ":friends", + ], + deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc deleted file mode 100644 index b17ff589e2597f8d1b5e61f4eaaed7d6ebe6214c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/graph_def_util.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/public/version.h" - -namespace tensorflow { - -static Status BuildLaunchNode( - const string& nodename, const string& function_name, - const AttrValueMap& function_attr, const string& device_name, - const DataTypeVector& constant_dtypes, int num_resources, - const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes, - Graph* graph, Node** node) { - NodeDef def; - def.set_name(graph->NewName(nodename)); - def.set_op("XlaLaunch"); - def.set_device(device_name); - AddNodeAttr("Tconstants", constant_dtypes, &def); - AddNodeAttr("Targs", arg_dtypes, &def); - AddNodeAttr("Nresources", num_resources, &def); - AddNodeAttr("Tresults", result_dtypes, &def); - NameAttrList function; - function.set_name(function_name); - *function.mutable_attr() = function_attr; - AddNodeAttr("function", function, &def); - - Status status; - *node = graph->AddNode(def, &status); - return status; -} - -static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { - VLOG(2) << "Replacing " << node->name() << " with XlaLaunch"; - - int num_constant_args, num_resource_args; - TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args)); - TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args)); - - if (num_constant_args < 0 || num_resource_args < 0 || - num_constant_args + num_resource_args > node->num_inputs()) { - return errors::InvalidArgument( - "Invalid number of constant/resource arguments to XLA kernel."); - } - const int num_nonconst_args = - node->num_inputs() - num_constant_args - num_resource_args; - - DataTypeVector const_dtypes(node->input_types().begin(), - node->input_types().begin() + num_constant_args); - DataTypeVector arg_dtypes( - node->input_types().begin() + num_constant_args, - node->input_types().begin() + num_constant_args + num_nonconst_args); - - // Build a XlaLaunch operator to execute the function body. - Node* launch_node; - TF_RETURN_IF_ERROR(BuildLaunchNode( - graph->NewName(node->name()), node->type_string(), node->def().attr(), - node->requested_device(), const_dtypes, num_resource_args, arg_dtypes, - node->output_types(), graph, &launch_node)); - launch_node->set_assigned_device_name(node->assigned_device_name()); - - // Copy incoming edges to the launch node. - for (const Edge* edge : node->in_edges()) { - if (edge->IsControlEdge()) { - graph->AddControlEdge(edge->src(), launch_node); - } else { - graph->AddEdge(edge->src(), edge->src_output(), launch_node, - edge->dst_input()); - } - } - - // Copy outgoing edges to the launch node. - std::vector out_edges(node->out_edges().begin(), - node->out_edges().end()); - for (const Edge* edge : out_edges) { - Node* dst = edge->dst(); - int src_output = edge->src_output(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - if (edge->IsControlEdge()) { - graph->AddControlEdge(launch_node, dst); - } else { - graph->AddEdge(launch_node, src_output, dst, dst_input); - } - } - graph->RemoveNode(node); - - return Status::OK(); -} - -Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) { - Graph* graph = options.graph->get(); - - for (Node* n : graph->op_nodes()) { - // In all cases, only try to compile computational nodes. - if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { - continue; - } - - // Only compile nodes that are marked for compilation by the - // compilation-marking pass (via 'attr_name'). - if (IsXlaCompiledKernel(*n)) { - TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n)); - } - } - - if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph, - options.flib_def); - } - return Status::OK(); -} -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..13a518d0e8b97c920c5c720a34ab92abdf1908dd --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -0,0 +1,189 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +static Status BuildXlaCompileNode( + const string& nodename, const string& function_name, + const AttrValueMap& function_attr, const string& device_name, + const DataTypeVector& constant_dtypes, int num_resources, + const DataTypeVector& arg_dtypes, Graph* graph, Node** node) { + NodeDef def; + def.set_name(graph->NewName(nodename)); + def.set_op("_XlaCompile"); + def.set_device(device_name); + AddNodeAttr("Tconstants", constant_dtypes, &def); + AddNodeAttr("Targs", arg_dtypes, &def); + AddNodeAttr("Nresources", num_resources, &def); + NameAttrList function; + function.set_name(function_name); + *function.mutable_attr() = function_attr; + AddNodeAttr("function", function, &def); + + Status status; + *node = graph->AddNode(def, &status); + return status; +} + +static Status BuildXlaRunNode(const string& nodename, const string& device_name, + const DataTypeVector& arg_dtypes, + const DataTypeVector& result_dtypes, Graph* graph, + Node** node) { + NodeDef def; + def.set_name(graph->NewName(nodename)); + def.set_op("_XlaRun"); + def.set_device(device_name); + AddNodeAttr("Targs", arg_dtypes, &def); + AddNodeAttr("Tresults", result_dtypes, &def); + + Status status; + *node = graph->AddNode(def, &status); + return status; +} + +static Status GetXlaAttrs(Node* node, int* num_constant_args, + int* num_resource_args, DataTypeVector* const_dtypes, + DataTypeVector* arg_dtypes) { + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args)); + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args)); + + if (*num_constant_args < 0 || *num_resource_args < 0 || + *num_constant_args + *num_resource_args > node->num_inputs()) { + return errors::InvalidArgument( + "Invalid number of constant/resource arguments to XLA kernel."); + } + + const int num_nonconst_args = + node->num_inputs() - *num_constant_args - *num_resource_args; + + const DataTypeVector& input_types = node->input_types(); + std::copy(input_types.begin(), input_types.begin() + *num_constant_args, + std::back_inserter(*const_dtypes)); + std::copy(input_types.begin() + *num_constant_args, + input_types.begin() + *num_constant_args + num_nonconst_args, + std::back_inserter(*arg_dtypes)); + return Status::OK(); +} + +static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node, + int prefix_to_ignore) { + for (const Edge* edge : old_node->in_edges()) { + if (edge->IsControlEdge()) { + g->AddControlEdge(edge->src(), new_node); + } else if (edge->dst_input() >= prefix_to_ignore) { + g->AddEdge(edge->src(), edge->src_output(), new_node, + edge->dst_input() - prefix_to_ignore); + } + } +} + +static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { + std::vector out_edges(old_node->out_edges().begin(), + old_node->out_edges().end()); + for (const Edge* edge : out_edges) { + Node* dst = edge->dst(); + int src_output = edge->src_output(); + int dst_input = edge->dst_input(); + g->RemoveEdge(edge); + + if (edge->IsControlEdge()) { + g->AddControlEdge(new_node, dst); + } else { + g->AddEdge(new_node, src_output, dst, dst_input); + } + } +} + +static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) { + int num_constant_args, num_resource_args; + DataTypeVector const_dtypes; + DataTypeVector arg_dtypes; + + TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args, + &const_dtypes, &arg_dtypes)); + + Node *compile_node, *run_node; + + TF_RETURN_IF_ERROR(BuildXlaCompileNode( + n->name(), n->type_string(), n->def().attr(), n->requested_device(), + const_dtypes, num_resource_args, arg_dtypes, g, &compile_node)); + + DataTypeVector arg_dtypes_with_resources = arg_dtypes; + for (int i = 0; i < num_resource_args; i++) { + arg_dtypes_with_resources.push_back(DT_RESOURCE); + } + + TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(), + arg_dtypes_with_resources, + n->output_types(), g, &run_node)); + + compile_node->set_assigned_device_name(n->assigned_device_name()); + run_node->set_assigned_device_name(n->assigned_device_name()); + + CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node, + /*prefix_to_ignore=*/0); + CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node, + /*prefix_to_ignore=*/num_constant_args); + + // The compilation_key output. + g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args); + + MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node); + g->RemoveNode(n); + + return Status::OK(); +} + +Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + + for (Node* n : graph->op_nodes()) { + // In all cases, only try to compile computational nodes. + if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { + continue; + } + + // Only compile nodes that are marked for compilation by the + // compilation-marking pass (via 'attr_name'). + if (IsXlaCompiledKernel(*n)) { + TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n)); + } + } + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def); + } + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h similarity index 71% rename from tensorflow/compiler/jit/build_xla_launch_ops_pass.h rename to tensorflow/compiler/jit/build_xla_ops_pass.h index 1dfea93f02081404c5c3c6686a8b28a8530ae8a3..1dd38fa95186dfbe458166caa23a131fbe3c9510 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h +++ b/tensorflow/compiler/jit/build_xla_ops_pass.h @@ -13,19 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ -#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ +#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -class BuildXlaLaunchOpsPass : public GraphOptimizationPass { +// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and +// executes (using XLA) TF function calls marked with "_XlaCompiledKernel". +class BuildXlaOpsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ +#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index a2e6285339f9ed0bde8d72f5b4752b1ecc22f426..6f1ff85f24a4c1fd3e6d54fcff9f8868aee6f750 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -14,8 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -125,7 +126,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, const DataTypeVector& arg_types = (*fbody)->arg_types; std::vector const_args(arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *((*fbody)->graph), &const_args, /*compile_time_const_nodes=*/nullptr)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { @@ -207,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, // device memory. // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory + // in device memory except for resources. MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + for (int i = 0; i < fbody->ret_types.size(); ++i) { + if (fbody->ret_types[i] == DT_RESOURCE) { + output_memory_types[i] = HOST_MEMORY; + } + } // Create the kernel. NameAttrList function; @@ -223,8 +230,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - *kernel = MakeUnique(&construction, constant_arg_indices, - resource_arg_indices, function); + *kernel = absl::make_unique( + &construction, constant_arg_indices, resource_arg_indices, function); return s; } diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc index b75ab486b80e098bc0a59f9ea8cdbaa23a28fef9..73866607621cd745f6e640a14405daebf0dd9985 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" @@ -65,11 +66,11 @@ class CreateXlaLaunchOpTest : public ::testing::Test { for (const auto& fdef : flib) { *(proto.add_function()) = fdef; } - lib_def_ = - MakeUnique(OpRegistry::Global(), proto); + lib_def_ = absl::make_unique( + OpRegistry::Global(), proto); OptimizerOptions opts; - device_mgr_ = MakeUnique(devices_); - pflr_ = MakeUnique( + device_mgr_ = absl::make_unique(devices_); + pflr_ = absl::make_unique( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 8aff87e5e620fefd30eeb902209c9bc17540f468..9128b48da3fe9dd3d85d146e16c153c1b3bebf4c 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" @@ -21,18 +22,79 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" // ALGORITHM OVERVIEW +// ================== // // We map every output produced by each node in the TensorFlow graph (including // control dependence) into an instance of the Predicate class. Instances of // Predicate denote logical formulas and mapping a node `n` to a predicate -// `pred` implies that `n` is executed whenver `pred` is true. Then we can -// deduce mismatching liveness in the inputs to node by comparing the predicate -// those inputs are mapped to. +// `pred` implies that `n` is live whenever `pred` is true. Then we can deduce +// mismatching liveness in the inputs to node by comparing the predicate those +// inputs are mapped to. The core logic of this pass resides in creating the +// map from TensorFlow nodes to predicates. // -// Loops are handled pessimistically -- we map Merge nodes with backedges to -// uninterpreted symbols (the same kind we use to represent Switch and _Recv). -// Predicate equality has to hold over all possible assignments to these -// uninterpreted symbols. +// +// MAPPING NODES TO PREDICATES, MODULO CYCLES +// ------------------------------------------ +// +// If we ignore cycles for a moment, computing predicates is fairly +// straightforward. We traverse the graph in RPO, mapping each node to a +// predicate based on the predicates its inputs are mapped to. For instance a +// Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)). +// Roughtly speaking, we abstract interpret each node on the "liveness" domain, +// where values in the domain represent if a tensor carries a dead signal or +// not. +// +// +// DEALING WITH CYCLES +// ------------------- +// +// We map Merge nodes that are the target of a backedge to AndRecurrence +// instances. An AndRecurrence with start() = S and step() = X, printed as +// {S,&,X}, *roughly* represents the infinite list of predicates +// [S,S&X,S&X&X,S&X&X, ...]. So {S,&,X} can be used to represent the predicate +// for Merge in a graph like: +// +// Init +// | +// v +// Merge <-----------+ +// | | +// v | +// Incr | +// | | +// v | +// Switch <- Cond | +// | | +// v (oidx: 1) | +// | | +// +---------------+ +// +// Where S is the predicate for Init and X is the predicate that asserts that +// Cond is true. {S,&,X} states that Merge is live on the first "iteration" iff +// S is true, live on the second iteration iff "S&X" is true, live on the third +// iteration iff "S&X&X" is true etc. There is a subtlety here, S&X&X would +// normally be equivalent to S&X which isn't quite what we want to represent. +// Instead we want {S,&,X} to denote the infinite list [S, S&X, +// S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is +// true on iteration 0, 1, 2 respectively. This is made more precise in the +// comment on the AndRecurrence class. +// +// The general algorithm that deals with cycles does two RPO (reverse post +// order) passes over the graph. On the first pass it assigns a symbolic +// predicate to merge nodes with backedges. On the second pass it tries to +// pattern matche the predicates for the backedges of these merges and infer an +// AndRecurrence for the merge. +// +// In other words, we do a pessimistic data flow analysis where the data-flow +// lattice has two elements, Symbolic and NonSymbolic with Symbolic > +// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to +// converge. We don't do an optimistic data flow analysis to make pattern +// matching easier: if we assigned the predicate of the initial value to the +// merge during the first pass, on the second pass the backedge may see a +// simplified value that would be difficult to pattern match. +// +// We still use symbolic predicates for merges for which we can't pattern match +// on the backedge predicate. This is conservatively correct. namespace tensorflow { @@ -42,14 +104,21 @@ namespace { // above. class Predicate { public: - enum class Kind { kAnd, kOr, kNot, kSymbol }; + enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol }; virtual string ToString() const = 0; int64 hash() const { return hash_; } + virtual absl::Span GetOperands() const = 0; virtual Kind kind() const = 0; virtual ~Predicate() {} + // Invokes func on p and on all of its operands recursively. Does not invoke + // `func` on the same Predicate instance twice. Aborts the search if `func` + // returns true. + template + static void Visit(Predicate* p, const FunctionTy& func); + protected: explicit Predicate(int64 hash) : hash_(hash) {} @@ -60,7 +129,7 @@ class Predicate { }; int64 HashPredicateSequence(Predicate::Kind kind, - gtl::ArraySlice preds) { + absl::Span preds) { int64 hash = ::tensorflow::hash()(kind); for (Predicate* pred : preds) { hash = Hash64Combine(hash, pred->hash()); @@ -85,12 +154,15 @@ class AndPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); + return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); } Kind kind() const override { return Kind::kAnd; } - const gtl::ArraySlice operands() const { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } + absl::Span operands() const { return operands_; } private: std::vector operands_; @@ -113,11 +185,14 @@ class OrPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); + return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); } Kind kind() const override { return Kind::kOr; } - const gtl::ArraySlice operands() const { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } + absl::Span operands() const { return operands_; } private: std::vector operands_; @@ -128,23 +203,62 @@ class NotPredicate : public Predicate { public: explicit NotPredicate(Predicate* operand) : Predicate(HashPredicateSequence(Kind::kNot, {operand})), - operand_(operand) {} + operands_({operand}) {} string ToString() const override { - return strings::StrCat("~", operand()->ToString()); + return absl::StrCat("~", operand()->ToString()); } Kind kind() const override { return Kind::kNot; } - Predicate* operand() const { return operand_; } + Predicate* operand() const { return operands_[0]; } + absl::Span GetOperands() const override { + return operands_; + } + + private: + std::array operands_; +}; + +// Represents an infinite list of predicates. +// +// An AndRecurrence with start = S and step = X is printed as {S,&,X} and stands +// for the list of predicates: +// +// S, S & GenSym(X,1), S & GenSym(X,1) & GenSym(X,2), ... +// +// where GenSym(, ) renames every SymbolPredicate in +// by appending to it, in effect creating a "fresh" symbol. +// This means {P,&,Q} is not equal to "P on the first iteration; P&Q on +// subsequent iterations". +class AndRecurrencePredicate : public Predicate { + public: + explicit AndRecurrencePredicate(Predicate* start, Predicate* step) + : Predicate(HashPredicateSequence(Kind::kAndRecurrence, {start, step})), + operands_({start, step}) {} + + Predicate* start() const { return operands_[0]; } + Predicate* step() const { return operands_[1]; } + + string ToString() const override { + return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), + "}"); + } + + Kind kind() const override { return Kind::kAndRecurrence; } + + absl::Span GetOperands() const override { + return operands_; + } private: - Predicate* operand_; + std::array operands_; }; // Represents an uninterpreted symbol in a logical predicate. // // Two predicates are equivalent iff they are equivalent for all assignments to -// the symbols contained in them. +// the symbols contained in them, i.e. predicates are forall qualified over +// symbols. class SymbolPredicate : public Predicate { public: explicit SymbolPredicate(TensorId tensor_id, bool must_be_true) @@ -153,11 +267,12 @@ class SymbolPredicate : public Predicate { must_be_true_(must_be_true) {} string ToString() const override { - return must_be_true() ? strings::StrCat("*", tensor_id_.ToString()) + return must_be_true() ? absl::StrCat("*", tensor_id_.ToString()) : tensor_id_.ToString(); } Kind kind() const override { return Kind::kSymbol; } + absl::Span GetOperands() const override { return {}; } // If `must_be_true()` is true this SymbolPredicate represents the proposition // "tensor_id() is live and evaluates to true". @@ -179,15 +294,38 @@ class SymbolPredicate : public Predicate { } }; +template +/*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) { + gtl::FlatSet visited; + std::vector stack; + + stack.push_back(p); + visited.insert(p); + + while (!stack.empty()) { + Predicate* current = stack.back(); + stack.pop_back(); + bool done = func(current); + if (done) { + return; + } + for (Predicate* op : current->GetOperands()) { + if (visited.insert(op).second) { + stack.push_back(op); + } + } + } +} + // Creates and owns Predicate instances. Simplifies predicates as it creates // them. class PredicateFactory { public: - Predicate* MakeAndPredicate(gtl::ArraySlice operands) { + Predicate* MakeAndPredicate(absl::Span operands) { return MakeAndOrImpl(operands, /*is_and=*/true); } - Predicate* MakeOrPredicate(gtl::ArraySlice operands) { + Predicate* MakeOrPredicate(absl::Span operands) { return MakeAndOrImpl(operands, /*is_and=*/false); } @@ -204,6 +342,21 @@ class PredicateFactory { } } + Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step) { + auto it = interned_and_rec_instances_.find({start, step}); + if (it != interned_and_rec_instances_.end()) { + return it->second.get(); + } + + std::unique_ptr new_pred = + Make(start, step); + Predicate* new_pred_ptr = new_pred.get(); + CHECK(interned_and_rec_instances_ + .emplace(SignatureForAndRec(start, step), std::move(new_pred)) + .second); + return new_pred_ptr; + } + Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { SignatureForSymbol signature = {tensor_id, must_be_true}; auto it = interned_symbol_instances_.find(signature); @@ -229,7 +382,7 @@ class PredicateFactory { new PredicateT(std::forward(args)...)); } - Predicate* MakeAndOrImpl(gtl::ArraySlice operands, bool is_and); + Predicate* MakeAndOrImpl(absl::Span operands, bool is_and); // Predicate instances are interned, meaning that there is only a single // instance of a Predicate object with a given content. This makes checking @@ -242,8 +395,9 @@ class PredicateFactory { // for the owning pointers to predicate instances. using SignatureForAndOr = - std::pair>; + std::pair>; using SignatureForNot = Predicate*; + using SignatureForAndRec = std::pair; using SignatureForSymbol = std::pair; struct HashSignatureForAndOr { @@ -268,14 +422,16 @@ class PredicateFactory { interned_and_or_instances_; gtl::FlatMap> interned_not_instances_; + gtl::FlatMap> + interned_and_rec_instances_; gtl::FlatMap, HashSignatureForSymbol> interned_symbol_instances_; }; // Common code to create AndPredicate or OrPredicate instances. -Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, - bool is_and) { +Predicate* PredicateFactory::MakeAndOrImpl( + absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; gtl::FlatSet simplified_ops_set; @@ -288,10 +444,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, if (op->kind() == pred_kind) { // "Inline" the operands of an inner And/Or into the parent And/Or. - gtl::ArraySlice operands = - is_and ? dynamic_cast(op)->operands() - : dynamic_cast(op)->operands(); - for (Predicate* subop : operands) { + for (Predicate* subop : op->GetOperands()) { if (simplified_ops_set.insert(subop).second) { simplified_ops.push_back(subop); } @@ -329,7 +482,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, // NB! Because we'll use a non-owning reference to simplified_ops in the // key for interned_and_or_instances_ we need to be careful to std::move() // it all the way through. - gtl::ArraySlice operands_slice = simplified_ops; + absl::Span operands_slice = simplified_ops; std::unique_ptr new_pred = is_and ? Make(std::move(simplified_ops)) : Make(std::move(simplified_ops)); @@ -351,6 +504,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} Status Populate(); + Status PopulateWithReversePostOrder(absl::Span rpo); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; gtl::FlatMap PredicateMapAsString() const; @@ -359,20 +513,40 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; std::vector GetIncomingPreds(Node* n, EdgeKind edge_kind); - void SetPred(Node* n, int output_idx, Predicate* pred) { - CHECK( - predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second); + + // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th + // bit of `should_revisit` if `pred` is different from the current predicate + // for the `output_idx` output of `n`. + void SetPredicate(Node* n, int output_idx, Predicate* pred, + std::vector* should_revisit) { + auto insert_result = + predicate_map_.insert({TensorId(n->name(), output_idx), pred}); + if (!insert_result.second && insert_result.first->second != pred) { + VLOG(4) << "For " << n->name() << ":" << output_idx << " from " + << insert_result.first->second->ToString() << " " + << insert_result.first->second << " to " << pred->ToString() + << " " << pred; + insert_result.first->second = pred; + if (should_revisit != nullptr) { + for (const Edge* e : n->out_edges()) { + (*should_revisit)[e->dst()->id()] = true; + } + } + } } - void SetPred(Node* n, gtl::ArraySlice output_idxs, Predicate* pred) { + + void SetPredicate(Node* n, absl::Span output_idxs, Predicate* pred, + std::vector* should_revisit) { for (int output_idx : output_idxs) { - SetPred(n, output_idx, pred); + SetPredicate(n, output_idx, pred, should_revisit); } } - Status HandleSwitch(Node* n); - Status HandleMerge(Node* n); - Status HandleRecv(Node* n); - Status HandleGeneric(Node* n); + Status HandleSwitch(Node* n, std::vector* should_revisit); + Status HandleMerge(Node* n, std::vector* should_revisit); + Status HandleRecv(Node* n, std::vector* should_revisit); + Status HandleGeneric(Node* n, std::vector* should_revisit); + Status HandleNode(Node* n, std::vector* should_revisit); const Graph& graph_; gtl::FlatMap predicate_map_; @@ -395,14 +569,15 @@ std::vector DeadnessAnalysisImpl::GetIncomingPreds( if (should_process) { auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); - CHECK(it != predicate_map_.end()); + CHECK(it != predicate_map_.end()) << n->name(); incoming_preds.push_back(it->second); } } return incoming_preds; } -Status DeadnessAnalysisImpl::HandleSwitch(Node* n) { +Status DeadnessAnalysisImpl::HandleSwitch(Node* n, + std::vector* should_revisit) { std::vector input_preds = GetIncomingPreds(n, EdgeKind::kDataAndControl); const Edge* pred_edge; @@ -414,84 +589,252 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n) { // Output 0 is alive iff all inputs are alive and the condition is false. input_preds.push_back(false_switch); - SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds)); + SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); input_preds.pop_back(); // Output 1 is alive iff all inputs are alive and the condition is true. input_preds.push_back(true_switch); - SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds)); + SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); input_preds.pop_back(); - // Control is alive iff any inputs are alive. - SetPred(n, Graph::kControlSlot, - predicate_factory_.MakeAndPredicate(input_preds)); + // Control is alive iff all inputs are alive. + SetPredicate(n, Graph::kControlSlot, + predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); return Status::OK(); } -Status DeadnessAnalysisImpl::HandleMerge(Node* n) { +namespace { +const Edge* FindUniqueBackedge(Node* merge) { + CHECK(merge->IsMerge()); + const Edge* result = nullptr; + for (const Edge* e : merge->in_edges()) { + if (e->src()->IsNextIteration()) { + CHECK_EQ(result, nullptr) + << "Multiple backedges to " << merge->DebugString(); + result = e; + } + } + return result; +} + +// If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step +// does not contain `symbolic_predicate` as an inner (not top-level) operand +// then returns `Step`. Otherwise returns nullptr. +Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, + Predicate* symbolic_predicate, + Predicate* backedge_predicate) { + CHECK(dynamic_cast(symbolic_predicate)); + if (backedge_predicate->kind() != Predicate::Kind::kAnd) { + return nullptr; + } + + std::vector and_ops; + absl::Span recurrent_pred_ops = + backedge_predicate->GetOperands(); + + bool found_sym = false; + for (Predicate* and_op : recurrent_pred_ops) { + // We want the `symbol_predicate` to be the one of the operands of + // `backedge_predicate`, + if (and_op == symbolic_predicate) { + found_sym = true; + continue; + } + + // but we don't want it to be present anywhere else in the formula. E.g. we + // don't want the recurrent predicate to be + // symbol_predicate&(X|symbol_predicate). + bool found_sym_as_inner_operand = false; + auto has_self_as_inner_operand = [&](Predicate* p) { + if (p == symbolic_predicate) { + found_sym_as_inner_operand = true; + return true; // Stop searching, we're done. + } + + // Continue searching. + return false; + }; + + Predicate::Visit(and_op, has_self_as_inner_operand); + if (found_sym_as_inner_operand) { + return nullptr; + } + and_ops.push_back(and_op); + } + + return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr; +} +} // namespace + +Status DeadnessAnalysisImpl::HandleMerge(Node* n, + std::vector* should_revisit) { // Merge ignores deadness of its control inputs. A merge that isn't the - // target of a backedge has is alive iff any of its data inputs are. We treat - // the liveness of a merge that is the target of a backedge symbolically. + // target of a backedge has is alive iff any of its data inputs are. The + // liveness of a merge that is the target of a backedge can sometimes be + // represented using a AndRecurrencePredicate. If neither apply, we represent + // the liveness of the merge symbolically. + + bool has_unvisited_backedge = false; + for (const Edge* e : n->in_edges()) { + if (!e->IsControlEdge() && e->src()->IsNextIteration()) { + has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e)); + } + } + + auto it = predicate_map_.find(TensorId(n->name(), 0)); + if (it == predicate_map_.end()) { + if (has_unvisited_backedge) { + // We're visiting this merge for the first time and it has an unvisited + // backedge. + Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate( + TensorId(n->name(), 0), /*must_be_true=*/false); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, + should_revisit); + return Status::OK(); + } - bool has_backedge = std::any_of( - n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) { - return !e->IsControlEdge() && e->src()->IsNextIteration(); - }); + // We're visiting this merge for the first time and it is a acyclic merge. + Predicate* input_data_pred = predicate_factory_.MakeOrPredicate( + GetIncomingPreds(n, EdgeKind::kDataOnly)); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, + should_revisit); + return Status::OK(); + } - Predicate* input_data_pred = - has_backedge ? predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false) - : predicate_factory_.MakeOrPredicate( - GetIncomingPreds(n, EdgeKind::kDataOnly)); + if (it->second->kind() == Predicate::Kind::kSymbol) { + // Last time we visited this merge we only got a symbolic predicate because + // of an unvisited backedge. Try to pattern match the predicate expression + // for that backedge (which should be visited now) into an and recurrence + // for the merge node. + if (const Edge* unique_backedge = FindUniqueBackedge(n)) { + if (Predicate* step = DeduceStepPredicate( + &predicate_factory_, it->second, + predicate_map_[InputEdgeToTensorId(unique_backedge)])) { + // If the predicate for the backedge is "Sym&X" where "Sym" is the + // predicate for the merge then the merge has predicate {S,&,X} where S + // is the predicate for the merge ignoring the backedge. + std::vector non_recurrent_inputs; + for (const Edge* e : n->in_edges()) { + if (e != unique_backedge) { + non_recurrent_inputs.push_back( + predicate_map_[InputEdgeToTensorId(e)]); + } + } - SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred); + Predicate* start = + predicate_factory_.MakeOrPredicate(non_recurrent_inputs); + Predicate* and_rec = + predicate_factory_.MakeAndRecurrencePredicate(start, step); + SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); + return Status::OK(); + } + } + } return Status::OK(); } -Status DeadnessAnalysisImpl::HandleRecv(Node* n) { +Status DeadnessAnalysisImpl::HandleRecv(Node* n, + std::vector* should_revisit) { // In addition to being alive or dead based on the inputs, a _Recv can also // acquire a dead signal from a _Send. std::vector input_preds = GetIncomingPreds(n, EdgeKind::kDataAndControl); input_preds.push_back(predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false)); - SetPred(n, {0, Graph::kControlSlot}, - predicate_factory_.MakeAndPredicate(input_preds)); + SetPredicate(n, {0, Graph::kControlSlot}, + predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); return Status::OK(); } -Status DeadnessAnalysisImpl::HandleGeneric(Node* n) { +Status DeadnessAnalysisImpl::HandleGeneric(Node* n, + std::vector* should_revisit) { // Generally nodes are alive iff all their inputs are alive. Predicate* pred = predicate_factory_.MakeAndPredicate( GetIncomingPreds(n, EdgeKind::kDataAndControl)); for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { - SetPred(n, output_idx, pred); + SetPredicate(n, output_idx, pred, should_revisit); + } + SetPredicate(n, Graph::kControlSlot, pred, should_revisit); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleNode(Node* n, + std::vector* should_revisit) { + if (n->IsSwitch()) { + TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit)); + } else if (n->IsMerge()) { + TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit)); + } else if (n->IsControlTrigger()) { + SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), + nullptr); + } else if (n->IsRecv() || n->IsHostRecv()) { + TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit)); + } else if (n->IsNextIteration()) { + TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit)); + } else { + TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit)); } - SetPred(n, Graph::kControlSlot, pred); return Status::OK(); } Status DeadnessAnalysisImpl::Populate() { std::vector rpo; - GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{}, + GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(), /*edge_filter=*/[](const Edge& edge) { return !edge.src()->IsNextIteration(); }); + return PopulateWithReversePostOrder(rpo); +} +Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( + absl::Span rpo) { // This an abstract interpretation over the deadness propagation semantics of // the graph executor. + // + // We iterate over the graph twice, each time in RPO. On the first iteration + // merge nodes with backedges are mapped to symbolic predicates. On the + // second iteration we use the predicates assigned to the backedges in the + // previous iteration to infer a more precise predicate for the backedge merge + // nodes and all the nodes that transitively use it. + // + // We don't track the output indices for should_revisit. Instead, putting a + // node in `should_revisit` denotes that the deadness flowing out from any + // output from said node may have changed. This is fine; only switches + // propagate different deadness along different output edges, and since the + // delta is solely due to the input *values* (and not input deadness), the + // delta should not change in the second iteration. + std::vector should_revisit; + should_revisit.resize(graph_.num_node_ids()); for (Node* n : rpo) { - if (n->IsSwitch()) { - TF_RETURN_IF_ERROR(HandleSwitch(n)); - } else if (n->IsMerge()) { - TF_RETURN_IF_ERROR(HandleMerge(n)); - } else if (n->IsControlTrigger()) { - SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue()); - } else if (n->IsRecv() || n->IsHostRecv()) { - TF_RETURN_IF_ERROR(HandleRecv(n)); - } else { - TF_RETURN_IF_ERROR(HandleGeneric(n)); + VLOG(4) << "Visiting " << n->name(); + TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr)); + if (n->IsNextIteration()) { + // If this is a backedge for a merge node then remember to reprocess the + // merge the next time we run. + for (const Edge* e : n->out_edges()) { + if (e->dst()->IsMerge()) { + should_revisit[e->dst()->id()] = true; + } + } + } + } + + for (Node* n : rpo) { + // The nodes added to should_revisit in the previous loop need to be + // revisited now. Reprocesing these initial nodes may add *their* consumers + // to should_revisit, and these newly added nodes will also be processed by + // this very same loop. Since we're traversing the graph in reverse post + // order (producers before consumers) and HandleNode(n) can only ever add + // n's consumers to should_revisit, we won't "miss" an addition to + // should_revisit. + if (should_revisit[n->id()]) { + VLOG(4) << "Revisiting " << n->name(); + TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit)); } } @@ -587,6 +930,15 @@ Status ComputePredicates(const Graph& graph, *out_predicate_map = impl.PredicateMapAsString(); return Status::OK(); } + +Status ComputePredicates(const Graph& graph, + absl::Span reverse_post_order, + PredicateMapTy* out_predicate_map) { + DeadnessAnalysisImpl impl(&graph); + TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order)); + *out_predicate_map = impl.PredicateMapAsString(); + return Status::OK(); +} } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index cdef4051108fdc5d063ab592676c7644989155bf..3df2679c629ce801fc6c9006415dcd27b40c078e 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -26,6 +26,14 @@ namespace deadness_analysis_internal { // testing purposes only. using PredicateMapTy = gtl::FlatMap; Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); + +// Returns a map describing the predicate each Tensor was mapped to. For +// testing purposes only. Makes deadness analysis visit the graph in the order +// specified in `reverse_post_order` which must be a valid RPO for the graph +// minus NextIteration->Merge edges. +Status ComputePredicates(const Graph& graph, + absl::Span reverse_post_order, + PredicateMapTy* out_predicate_map); } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 6881095b51758d2e0b06c60021bc8c2860ac566e..28a56044d5e3795fc3ecf5d1092491b87cb90f01 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -32,12 +32,14 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { +using deadness_analysis_internal::ComputePredicates; +using deadness_analysis_internal::PredicateMapTy; + Status AnalyzeDeadness(Graph* graph, std::unique_ptr* result) { FixupSourceAndSinkEdges(graph); @@ -51,13 +53,73 @@ ops::Switch CreateSwitch(const Scope& root, const string& prefix) { return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate); } -Output CreateInductionVariable(const Scope& root, const string& prefix, - const string& frame_name, int32 init) { - Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init); +TensorId ControlOutputFor(const Output& o) { + return {o.node()->name(), Graph::kControlSlot}; +} + +void VLogGraphIfAsked(const Graph& graph) { + if (VLOG_IS_ON(3)) { + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + string serialized; + ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized); + LOG(INFO) << serialized; + } +} + +struct InductionVarInfo { + Output induction_var; + Output loop_cond; +}; + +// Creates an induction variable with the following structure (simplified for +// brevity): +// +// +---------------+ +// | initial_value | +// +---------------+ +// | +// | +// v +// +---------------+ +// | Enter | +// +---------------+ +// | +// | +// v +// +---------------+ +// +> | Merge | -+ +// | +---------------+ | +// | | | +// | | | +// | v | +// | +---------------+ | +// | | LessThan10 | | +// | +---------------+ | +// | | | +// | | | +// | v | +// | +---------------+ | +// +----+- | Switch | <+ +// | | +---------------+ +// | | | +// | | | +// | | v +// | | +---------------+ +// | +- | AddOne | +// | +---------------+ +// | +---------------+ +// +-----> | Exit | +// +---------------+ +InductionVarInfo CreateInductionVariable(const Scope& root, + const string& prefix, + const string& frame_name, + const Output& initial_value) { Output enter_initial_value = ops::internal::Enter( root.WithOpName(prefix + "/enter"), initial_value, frame_name); - ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value}); + ops::Merge iv(root.WithOpName(prefix + "/iv"), + {enter_initial_value, enter_initial_value}); Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); Output loop_cond_expr = @@ -66,16 +128,84 @@ Output CreateInductionVariable(const Scope& root, const string& prefix, ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); - Output iv_next = - ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by); + Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), + latch.output_true, increment_by); Output next_iteration = - ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next); + ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next); - root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1); + CHECK(root.graph() + ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1) + .ok()); root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); root.graph()->AddControlEdge(iv.output.node(), final_value.node()); - return iv.output; + return {iv.output, loop_cond}; +} + +InductionVarInfo CreateInductionVariable(const Scope& root, + const string& prefix, + const string& frame_name, int32 init) { + return CreateInductionVariable( + root, prefix, frame_name, + ops::Const(root.WithOpName(prefix + "/init"), init)); +} + +// Creates an induction variable with the following structure: +// +// +---------------+ +// | initial_value | +// +---------------+ +// | +// | +// v +// +---------------+ +// | Enter | +// +---------------+ +// | +// | +// v +// +---------------+ +// | Merge | <+ +// +---------------+ | +// | | +// | | +// v | +// +-----------+ +---------------+ | +// | loop_cond | --> | Switch | -+ +// +-----------+ +---------------+ +// | +// | +// v +// +---------------+ +// | Exit | +// +---------------+ +struct DependentInductionVar { + Output induction_var; + ops::Switch latch; +}; + +DependentInductionVar CreateDependentLoopInvariantValue( + const Scope& root, const string& prefix, const string& frame_name, + const Output& loop_cond, const Output& value) { + Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"), + value, frame_name); + ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value}); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + Output next_iteration = ops::NextIteration( + root.WithOpName(prefix + "/next_iteration"), latch.output_true); + CHECK(root.graph() + ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1) + .ok()); + return {iv.output, latch}; +} + +DependentInductionVar CreateDependentLoopInvariantValue( + const Scope& root, const string& prefix, const string& frame_name, + const Output& loop_cond, int32 value) { + return CreateDependentLoopInvariantValue( + root, prefix, frame_name, loop_cond, + ops::Const(root.WithOpName(prefix + "/init"), value)); } TEST(DeadnessAnalysisTest, BasicPositive) { @@ -337,21 +467,224 @@ TEST(DeadnessAnalysisTest, HostRecv) { TEST(DeadnessAnalysisTest, Loop) { Scope root = Scope::NewRootScope().ExitOnError(); - Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0); - Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0); - Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1); + Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0).induction_var; + Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0).induction_var; + Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1).induction_var; Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1); Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2); - std::unique_ptr result; - TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have // noticed that. Today we are pessimistic here because we assign an // uninterpreted symbol to merges with backedges. - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0 + // produce the same deadness. But we're not that smart today. + EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], "{#true,&,*iv1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], "{#true,&,*iv2/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); + EXPECT_EQ(predicate_map[ControlOutputFor(add1)], + "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); + } +} + +TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + Output dependent_iv0 = + CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0) + .induction_var; + Output dependent_iv1 = + CreateDependentLoopInvariantValue(root, "div1", "frame", iv.loop_cond, 0) + .induction_var; + Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node())); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + } +} + +TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { + // Create a merge that "looks like" a loop but isn't really. It has a value + // that does not depend on the merge on its backedge. + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + DependentInductionVar dependent_iv = + CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0); + FixupSourceAndSinkEdges(root.graph()); + + // To make deadness analysis think that dependent_iv is a loop we need an RPO + // that visits the merge before the backedge. This is a legal RPO for + // deadness analysis since it ignores NextIteration->Merge edges during RPO. + // Right now dependent_iv has an edge from Merge to NextIteration so do the + // RPO with this edge in place. Then remove this edge to get our test case. + std::vector rpo; + GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{}, + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + TF_ASSERT_OK(root.graph()->UpdateEdge( + iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0)); + + VLogGraphIfAsked(*root.graph()); + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], + "div0/iv:0"); + } +} + +TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_outer = + CreateInductionVariable(root, "iv_outer", "frame", 0); + ops::Switch inner_value(root.WithOpName("outer_is_live"), + ops::Const(root.WithOpName("constant"), 5), + iv_outer.loop_cond); + InductionVarInfo iv_inner = CreateInductionVariable( + root, "iv_inner", "frame", + ops::internal::Enter(root.WithOpName("iv_inner/enter"), + inner_value.output_true, "frame_inner")); + + Output dependent_outer_iv0 = + CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", "frame", + iv_outer.loop_cond, 0) + .induction_var; + Output dependent_outer_iv1 = + CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", "frame", + iv_outer.loop_cond, 0) + .induction_var; + + Output dependent_inner_iv0 = + CreateDependentLoopInvariantValue(root, "dependent_inner_iv0", "frame", + iv_inner.loop_cond, dependent_outer_iv0) + .induction_var; + Output dependent_inner_iv1 = + CreateDependentLoopInvariantValue(root, "dependent_inner_iv1", "frame", + iv_inner.loop_cond, dependent_outer_iv1) + .induction_var; + + Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0, + dependent_inner_iv1); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node())); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], + "{#true,&,*iv_outer/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], + "{(*iv_outer/cond:0 & {#true,&,*iv_outer/cond:0}),&," + "*iv_inner/cond:0}"); + + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], + "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," + "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], + "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," + "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," + "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + } +} + +TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_outer_0 = + CreateInductionVariable(root, "iv_outer_0", "frame", 0); + ops::Switch inner_value_0(root.WithOpName("outer_0_is_live"), + ops::Const(root.WithOpName("constant"), 5), + iv_outer_0.loop_cond); + InductionVarInfo iv_inner_0 = CreateInductionVariable( + root, "iv_inner_0", "frame", + ops::internal::Enter(root.WithOpName("iv_inner_0/enter"), + inner_value_0.output_true, "frame_inner")); + + InductionVarInfo iv_outer_1 = + CreateInductionVariable(root, "iv_outer_1", "frame", 1); + ops::Switch inner_init_value_1(root.WithOpName("outer_1_is_live"), + ops::Const(root.WithOpName("constant"), 5), + iv_outer_1.loop_cond); + InductionVarInfo iv_inner_1 = CreateInductionVariable( + root, "iv_inner_1", "frame", + ops::internal::Enter(root.WithOpName("iv_inner_1/enter"), + inner_init_value_1.output_true, "frame_inner")); + Output add0 = ops::Add(root.WithOpName("add0"), iv_inner_0.induction_var, + iv_inner_1.induction_var); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + } + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_0.induction_var)], + "{#true,&,*iv_outer_0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_0.induction_var)], + "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," + "*iv_inner_0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_1.induction_var)], + "{#true,&,*iv_outer_1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_1.induction_var)], + "{(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," + "*iv_inner_1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "({(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," + "*iv_inner_1/cond:0} & " + "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," + "*iv_inner_0/cond:0})"); + } } TEST(DeadnessAnalysisTest, ControlInputs) { @@ -454,9 +787,8 @@ TEST(DeadnessAnalysisTest, RecvVsSwitchText) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - deadness_analysis_internal::PredicateMapTy predicate_map; - TF_ASSERT_OK(deadness_analysis_internal::ComputePredicates(*root.graph(), - &predicate_map)); + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); TensorId logical_and_output_0 = {logical_and.node()->name(), Graph::kControlSlot}; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index fdd71c6a588ad96301f543651c8531e6f9c3ca05..e0632ff7e48ccea99d469f62ec9d0a3fe8295024 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -36,6 +38,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" @@ -44,8 +47,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" @@ -58,6 +59,22 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; +void SortControlInputs(GraphDef* gdef) { + int64 num_nodes = gdef->node_size(); + for (int64 i = 0; i < num_nodes; ++i) { + NodeDef* node = gdef->mutable_node(i); + // Stable sort control inputs and leave the order of data inputs unchanged. + std::stable_sort(node->mutable_input()->begin(), + node->mutable_input()->end(), + [](const string& a, const string& b) { + bool a_is_control = absl::StartsWith(a, "^"); + bool b_is_control = absl::StartsWith(b, "^"); + return (!a_is_control && b_is_control) || + (a_is_control && b_is_control && a < b); + }); + } +} + namespace { bool AreAllParentsGuaranteedConst( @@ -755,7 +772,7 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -790,7 +807,7 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); @@ -950,16 +967,15 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", - oc_subgraph_name, "_host_compute"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", + oc_subgraph_name, "_host_compute"), kHostComputeOp); builder.Input(inputs); builder.Attr("Tinputs", input_dtypes); builder.Attr("Toutputs", output_dtypes); builder.Attr("ancestors", host_compute_ancestors); - builder.Attr("key", - strings::StrCat("host_compute_channel_", subgraph_name, "_", - oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; @@ -1017,8 +1033,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; - NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), - "NoOp"); + NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); Status s = builder.Finalize(&seq_def); @@ -1091,10 +1106,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library); - dump_graph::DumpFunctionDefToFile( - strings::StrCat("encapsulate_fdef_", name), fdef); + dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), + *graph_, library); + dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), + fdef); } if (!reuse_existing_functions || library->Find(name) == nullptr) { @@ -1130,8 +1145,8 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( host_compute->AddAttr("shapes", shapes); } else { string inference_graph_name = - strings::StrCat("_outside_compilation_shape_inference_", subgraph_name, - "_", outside_compilation_subgraph_name); + absl::StrCat("_outside_compilation_shape_inference_", subgraph_name, + "_", outside_compilation_subgraph_name); FunctionDef fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); @@ -1155,14 +1170,13 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Replace function def " << name; dump_graph::DumpGraphToFile( - strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, + absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, library); dump_graph::DumpFunctionDefToFile( - strings::StrCat("replace_encapsulate_fdef_", name), fdef); + absl::StrCat("replace_encapsulate_fdef_", name), fdef); } - TF_RETURN_IF_ERROR(library->RemoveFunction(name)); - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); return Status::OK(); } @@ -1187,8 +1201,7 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - strings::StrCat(call_node_def_.name(), "_key_placeholder"), - "Placeholder"); + absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder"); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1222,16 +1235,16 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_recv"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); builder.Device(device_); builder.Attr("Toutputs", dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. builder.Attr("device_ordinal", 0); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); builder.Attr(group_attribute, subgraph_name); builder.Attr(outside_compilation_attribute, oc_subgraph_name); builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); @@ -1277,13 +1290,13 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_send"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_send"), kSendFromHostOp); builder.Device(device_); builder.Attr("Tinputs", dtypes); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. builder.Attr("device_ordinal", 0); @@ -1517,7 +1530,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { // Dump subgraphs. for (auto& entry : subgraphs_) { dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first), + absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first), *entry.second.GetGraph(), library); } } @@ -2053,7 +2066,7 @@ struct PathDetails { struct SubgraphAndClusterHash { inline std::size_t operator()(const SubgraphAndCluster& v) const { return hash()( - strings::StrCat(v.subgraph, v.outside_compilation_cluster)); + absl::StrCat(v.subgraph, v.outside_compilation_cluster)); } }; @@ -2505,7 +2518,8 @@ Status EncapsulateSubgraphsPass::Run( const int num_args = input_permutation->size(); std::vector const_args(num_args); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + **subgraph, &const_args, /*compile_time_const_nodes=*/nullptr)); DataTypeVector arg_types(num_args); TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 926589546fec72048485d30966f31b24e44b1245..90354a801afb26b003e00c4529069fdc61bbca32 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr; // Name of the attribute containing the number of resource variable arguments. extern const char* const kXlaNumResourceArgsAttr; +// Sorts each node's control inputs by their names. This guarantees that for two +// structually equivalent GraphDefs, we get the same traversal ordering on +// node's control input fields. +// TODO(hpucha): Move the utilities to a more appropriate place. +void SortControlInputs(GraphDef* gdef); + class EncapsulateSubgraphsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index c0543a00792235c5dd090e81930d8c219dc7f1a3..49958093b8dcf35e8adcdfd2f7dfce8558d5db6f 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -16,8 +16,10 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" @@ -25,7 +27,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -48,7 +49,7 @@ Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, FunctionDef* fdef = library->add_function(); TF_RETURN_IF_ERROR(GraphToFunctionDef( *graph, - strings::StrCat("_outside_compilation_shape_inference_", name_suffix), + absl::StrCat("_outside_compilation_shape_inference_", name_suffix), fdef)); return Status::OK(); } @@ -65,18 +66,18 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const auto iter = b.find(elt_a.first); if (iter == b.end()) { if (diff) { - *diff = strings::StrCat( - map_name, " expected: contains element with key '", - key_to_string(elt_a.first), "' got: map has no such element"); + *diff = absl::StrCat(map_name, " expected: contains element with key '", + key_to_string(elt_a.first), + "' got: map has no such element"); } return false; } if (!compare(elt_a.first, elt_a.second, iter->second)) { if (diff) { - *diff = strings::StrCat(map_name, " expected: element with key '", - key_to_string(elt_a.first), "' has value '", - value_to_string(elt_a.second), "' got: '", - value_to_string(iter->second), "'"); + *diff = absl::StrCat(map_name, " expected: element with key '", + key_to_string(elt_a.first), "' has value '", + value_to_string(elt_a.second), "' got: '", + value_to_string(iter->second), "'"); } return false; } @@ -85,9 +86,9 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const auto iter = a.find(elt_b.first); if (iter == a.end()) { if (diff) { - *diff = strings::StrCat(map_name, " got: contains element with key '", - key_to_string(elt_b.first), - "' expected: map has no such element"); + *diff = absl::StrCat(map_name, " got: contains element with key '", + key_to_string(elt_b.first), + "' expected: map has no such element"); } return false; } @@ -99,38 +100,38 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, const string& diff_preamble, string* diff) { if (a.op() != b.op()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected op '", a.op(), "' got '", b.op()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected op '", a.op(), "' got '", b.op()); } return false; } if (a.device() != b.device()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected device '", a.device(), "' got '", - b.device()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected device '", a.device(), "' got '", + b.device()); } return false; } if (a.input_size() != b.input_size()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected ", a.input_size(), " inputs got ", - b.input_size(), " expected:\n", a.DebugString(), - "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected ", a.input_size(), " inputs got ", + b.input_size(), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } std::unordered_set control_input_a; std::unordered_set control_input_b; for (int i = 0; i < a.input_size(); ++i) { - if (str_util::StartsWith(a.input(i), "^")) { - if (!str_util::StartsWith(b.input(i), "^")) { + if (absl::StartsWith(a.input(i), "^")) { + if (!absl::StartsWith(b.input(i), "^")) { if (diff) { - *diff = strings::StrCat( - diff_preamble, " mismatch for node ", a.name(), " input ", i, - ", expected control input ", a.input(i), " got ", b.input(i), - " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected control input ", + a.input(i), " got ", b.input(i), " expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -138,19 +139,19 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, control_input_b.insert(b.input(i)); } else if (a.input(i) != b.input(i)) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - " input ", i, ", expected ", a.input(i), - " got ", b.input(i), " expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected ", a.input(i), " got ", + b.input(i), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } } if (control_input_a != control_input_b) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - " control inputs differ expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " control inputs differ expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -170,18 +171,17 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, return av.DebugString() == bv.DebugString(); } }, - strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), - diff); + absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff); } bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, string* diff) { if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { - *diff = strings::StrCat("Signature mismatch for function ", - a.signature().name(), ", expected:\n", - a.signature().DebugString(), "\ngot:\n", - b.signature().DebugString()); + *diff = + absl::StrCat("Signature mismatch for function ", a.signature().name(), + ", expected:\n", a.signature().DebugString(), "\ngot:\n", + b.signature().DebugString()); } return false; } @@ -191,7 +191,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const AttrValue& av, const AttrValue& bv) { return av.DebugString() == bv.DebugString(); }, - strings::StrCat("attr mismatch for function ", a.signature().name()), + absl::StrCat("attr mismatch for function ", a.signature().name()), diff)) { return false; } @@ -201,7 +201,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const string& av, const string& bv) { return av == bv; }, - strings::StrCat("ret mismatch for function ", a.signature().name()), + absl::StrCat("ret mismatch for function ", a.signature().name()), diff)) { return false; } @@ -211,7 +211,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, if (a.node_def(i).name() == b.node_def(j).name()) { if (!EqualFunctionNodeDef( a.node_def(i), b.node_def(j), - strings::StrCat("Function ", a.signature().name()), diff)) { + absl::StrCat("Function ", a.signature().name()), diff)) { return false; } found = true; @@ -220,9 +220,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", expected: has node '", a.node_def(i).name(), - "' got: no node of that name"); + *diff = absl::StrCat("Function ", a.signature().name(), + ", expected: has node '", a.node_def(i).name(), + "' got: no node of that name"); } return false; } @@ -237,9 +237,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", got: has node '", b.node_def(i).name(), - "' expected: no node of that name"); + *diff = absl::StrCat("Function ", a.signature().name(), + ", got: has node '", b.node_def(i).name(), + "' expected: no node of that name"); } return false; } @@ -258,8 +258,8 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, auto it = actual_index.find(expected_function.signature().name()); if (it == actual_index.end()) { if (diff) { - *diff = strings::StrCat("Did not find expected function '", - expected_function.signature().name(), "'"); + *diff = absl::StrCat("Did not find expected function '", + expected_function.signature().name(), "'"); } return false; } @@ -269,9 +269,9 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, if (!actual_index.empty()) { if (diff != nullptr) { - *diff = strings::StrCat("Found unexpected function '", - actual_index.begin()->second->signature().name(), - "'"); + *diff = + absl::StrCat("Found unexpected function '", + actual_index.begin()->second->signature().name(), "'"); } return false; } @@ -379,7 +379,7 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTestShaped", opts); } -Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, +Node* KnownShapeBase(DataType dtype, absl::Span shape, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", @@ -394,7 +394,7 @@ Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, .FinalizeBuilder(&node_builder); } -Node* KnownShape(const gtl::ArraySlice& shape, +Node* KnownShape(absl::Span shape, const GraphDefBuilder::Options& opts) { return KnownShapeBase(DT_FLOAT, shape, opts); } @@ -417,14 +417,12 @@ Node* KeyPlaceholder(const string& call_node, } Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, - const string& oc_cluster, - const gtl::ArraySlice& dtypes, + const string& oc_cluster, absl::Span dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_recv"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_recv"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); node_builder.Input(std::move(key_input)); @@ -441,10 +439,9 @@ Node* SendFromHost(ops::NodeOut key_input, const string& cluster, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_send"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_send"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); @@ -683,8 +680,8 @@ std::vector> GraphEdges(const Graph& graph) { for (const Edge* edge : graph.edges()) { if (edge->src()->IsSource() || edge->dst()->IsSink()) continue; edges.emplace_back( - strings::StrCat(edge->src()->name(), ":", edge->src_output()), - strings::StrCat(edge->dst()->name(), ":", edge->dst_input())); + absl::StrCat(edge->src()->name(), ":", edge->src_output()), + absl::StrCat(edge->dst()->name(), ":", edge->dst_input())); } std::sort(edges.begin(), edges.end()); return edges; @@ -768,7 +765,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -813,7 +810,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -892,13 +889,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "c:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, @@ -1038,26 +1035,26 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"F", "outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1190,13 +1187,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1213,13 +1210,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"G:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -1364,13 +1361,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1386,13 +1383,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"G:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F2_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"i_0_retval", "I:o:0"}}); @@ -1495,13 +1492,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {}, - {{"Tinputs", gtl::ArraySlice({})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1579,13 +1576,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {}, - {{"Tinputs", gtl::ArraySlice({})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1661,12 +1658,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1742,12 +1739,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1846,13 +1843,13 @@ TEST(EncapsulateSubgraphsTest, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}}, }, {{"h_0_retval", "H:o:0"}}); @@ -1955,13 +1952,13 @@ TEST(EncapsulateSubgraphsTest, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"h_0_retval", "H:o:0"}}); @@ -2066,37 +2063,37 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute", - "outside_compilation_O2_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"})}, {"key", "host_compute_channel_F1_O3"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}}, {"outside_compilation_O1_host_compute", "outside_compilation_O2_host_compute"}}}, @@ -2272,13 +2269,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"c:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..97ef8cd3cb3fba54259fc413e0a3d3e75a89c431 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -0,0 +1,360 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/fingerprint.h" + +namespace tensorflow { + +const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = + "_xla_compile_id"; + +namespace { + +const char* const kXlaClusterOutput = "XlaClusterOutput"; + +// Checks if a graph node is marked to be a guaranteed constant. +bool is_guaranteed_constant(const Node& n) { + bool guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant) + .ok()) { + return false; + } + return guaranteed_constant; +} + +// Finds the `index` of an _Arg or _Retval node. +Status GetIndexAttr(const Node& n, int num_args, int* index) { + TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index)); + if (*index < 0 || *index >= num_args) { + return errors::InvalidArgument("Invalid ", n.type_string(), " number ", + *index); + } + return Status::OK(); +} + +// Returns the data type of the destination of an edge. +DataType EdgeType(const Edge* edge) { + return edge->dst()->input_type(edge->dst_input()); +} + +// Adds the control inputs of `node` to `*deps`. +void AddControlInputs(const Node& node, gtl::FlatSet* deps) { + for (const Edge* edge : node.in_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->src()); + } + } +} + +// Adds the control outputs of `node` to `*deps`. +void AddControlOutputs(const Node& node, gtl::FlatSet* deps) { + for (const Edge* edge : node.out_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->dst()); + } + } +} + +// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts +// the arguments into the order expected by XlaLaunch computations: +// 1) arguments +// 2) resource variable arguments +// See the documentation of EncapsulateSubgraphsInFunctions for the meaning +// of the arguments. +// +// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed. +Status RewriteSubgraph(const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + const int num_args = input_permutation->size(); + const int num_retvals = output_permutation->size(); + + std::vector args; + std::vector retvals; + args.reserve(num_args); + retvals.reserve(num_retvals); + for (Node* n : graph->nodes()) { + if (n->type_string() == "_Arg") { + // Check if this is a guaranteed constant. + if (is_guaranteed_constant(*n)) { + return errors::InvalidArgument( + "Guaranteed constants are not supported (", n->name(), ")"); + } + args.push_back(n); + } else if (n->type_string() == "_Retval") { + retvals.push_back(n); + } + } + + if (std::find(args.begin(), args.end(), nullptr) != args.end()) { + return errors::InvalidArgument("Missing or non-consecutive arguments"); + } + + // Reorders the arguments. + std::sort(args.begin(), args.end(), [&](Node* a, Node* b) { + // Non-resources appear before resources + bool a_is_resource = (a->output_type(0) == DT_RESOURCE); + bool b_is_resource = (b->output_type(0) == DT_RESOURCE); + // Uses the name as a tiebreaker so the output is deterministic. + StringPiece a_name(a->name()); + StringPiece b_name(b->name()); + return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name); + }); + + // Sorts the retvals by name so the order is deterministic. + std::sort(retvals.begin(), retvals.end(), + [](Node* a, Node* b) { return a->name() < b->name(); }); + + // Computes the permutation to produce the correct argument order, and update + // the argument indices. + int variable_start_index = num_args; + for (int i = 0; i < num_args; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index)); + if (args[i]->output_type(0) == DT_RESOURCE && + variable_start_index == num_args) { + variable_start_index = i; + } + (*input_permutation)[index] = i; + args[i]->AddAttr("index", i); + } + VLOG(4) << "variable_start_index: " << variable_start_index; + + // Computes the permutation to produce the correct retval order, and update + // the argument indices. + for (int i = 0; i < num_retvals; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index)); + (*output_permutation)[index] = i; + retvals[i]->AddAttr("index", i); + } + + AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), + call_def); + AddNodeAttr("_variable_start_index", variable_start_index, call_def); + + // Uniquify the function name. + GraphDef gdef; + graph->ToGraphDef(&gdef); + + // Before serialization, sort each node's control inputs to achieve + // determinism. Sorting control inputs could help (but not necessarily) create + // a deterministic serialization and fingerprint. Other sources of + // nondeterminism include unstable node ordering. + SortControlInputs(&gdef); + // Fingerprint the function. + // Nondeterminism in serialization would not lead to incorrect results, but + // may cause spurious cache misses. DeterministicSerialization is a + // best-effort deterministic serialization. + string serialized; + TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); + uint64 fingerprint = Fingerprint64(serialized); + LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); + return Status::OK(); +} + +} // namespace + +/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // Check for undeclared outputs before Encapsulation, so we can give a better + // error message. + // TODO(phawkins): merge this with the encapsulation code to avoid the extra + // O(n) pass over the edges. + for (const Edge* e : (*graph)->edges()) { + if (!e->IsControlEdge() && + e->src()->attrs().Find(kXlaClusterAttr) != nullptr && + e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && + e->dst()->type_string() != kXlaClusterOutput) { + return errors::InvalidArgument( + "Undeclared output of XLA computation. A common cause of this error " + "is variable initializers that depend on the XLA computation. Edge: ", + e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", + e->dst_input()); + } + } + + auto output = absl::make_unique((*graph)->op_registry()); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, "", **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, &output, flib_def), + "EncapsulateXlaComputationsPass failed"); + graph->swap(output); + return Status::OK(); +} + +/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( + Graph* graph) { + // Finds all of the XlaLaunch function calls, to avoid mutating the graph + // while iterating. + std::vector launch_nodes; + for (Node* n : graph->nodes()) { + string name; + if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) { + launch_nodes.push_back(n); + } + } + + // Replaces each launch function call together with its neighboring + // XlaClusterOutput nodes with a XlaLaunch node. + for (Node* launch : launch_nodes) { + int variable_start_index; + TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index", + &variable_start_index)); + + std::vector in_edges; + TF_RETURN_IF_ERROR(launch->input_edges(&in_edges)); + + const int num_inputs = in_edges.size(); + const int num_variables = num_inputs - variable_start_index; + const int num_args = variable_start_index; + + VLOG(4) << "Launch node '" << launch->name() << "'" + << " input edges: " << in_edges.size() << " num_args: " << num_args + << " num_variables: " << num_variables; + + std::vector nodes_to_remove = {launch}; + + // Data and control inputs to the new XlaLaunch node. + std::vector> data_inputs(num_inputs); + gtl::FlatSet control_inputs; + DataTypeVector arg_types(num_args); + + AddControlInputs(*launch, &control_inputs); + + for (int i = 0; i < num_args; ++i) { + const Edge* edge = in_edges[i]; + data_inputs[i] = {edge->src(), edge->src_output()}; + arg_types[i] = EdgeType(edge); + } + + // Appends the variable inputs. + for (int i = 0; i < num_variables; ++i) { + int pos = variable_start_index + i; + const Edge* edge = in_edges[pos]; + data_inputs[pos] = {edge->src(), edge->src_output()}; + } + + // Outputs. + const int num_outputs = launch->output_types().size(); + gtl::FlatSet control_outputs; + std::vector>> data_outputs(num_outputs); + DataTypeVector output_types(num_outputs); + + for (const Edge* le : launch->out_edges()) { + if (le->IsControlEdge()) { + control_outputs.insert(le->dst()); + } else { + TF_RET_CHECK(le->src_output() < num_outputs); + Node* output_node = le->dst(); + + TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput) + << le->DebugString(); + nodes_to_remove.push_back(output_node); + + for (const Edge* oe : output_node->out_edges()) { + TF_RET_CHECK(!oe->IsControlEdge()); + data_outputs[le->src_output()].push_back( + {oe->dst(), oe->dst_input()}); + } + output_types[le->src_output()] = output_node->input_type(0); + + AddControlOutputs(*output_node, &control_outputs); + } + } + + NodeDef def; + def.set_name(launch->name()); + + // Target the XLA CPU/GPU backends. + VLOG(2) << "Replacing with XlaLaunch"; + def.set_op("XlaLaunch"); + AddNodeAttr("Tconstants", DataTypeVector{}, &def); + AddNodeAttr("Targs", arg_types, &def); + AddNodeAttr("Nresources", num_variables, &def); + AddNodeAttr("Tresults", output_types, &def); + NameAttrList function; + function.set_name(launch->type_string()); + AddNodeAttr("function", function, &def); + + for (Node* node : nodes_to_remove) { + VLOG(2) << "Deleting node " << node->DebugString(); + // Ensure that we do not attempt to add control edges to nodes that are + // deleted. + control_inputs.erase(node); + control_outputs.erase(node); + graph->RemoveNode(node); + } + + Status status; + Node* xla_launch = graph->AddNode(def, &status); + if (!status.ok()) { + return status; + } + for (int i = 0; i < data_inputs.size(); ++i) { + graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch, + i); + } + for (Node* n : control_inputs) { + graph->AddControlEdge(n, xla_launch); + } + for (int i = 0; i < data_outputs.size(); ++i) { + for (const auto& successor : data_outputs[i]) { + graph->AddEdge(xla_launch, i, successor.first, successor.second); + } + } + for (Node* n : control_outputs) { + graph->AddControlEdge(xla_launch, n); + } + } + return Status::OK(); +} + +Status EncapsulateXlaComputationsPass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "EncapsulateXlaComputations(): " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); + VLOG(1) << "EncapsulateXlaComputations() half-way: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); + VLOG(1) << "EncapsulateXlaComputations() finished: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", + **options.graph, options.flib_def); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..99e9dfd598f29697dd009aa32f5317ed3dc647ae --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +// Rewrites computations generated by the xla.compile() Python code into +// XlaLaunch nodes. +// +// xla.compile() does two main things: +// a) marks operators that make up an XLA computation with the attribute +// _xla_compile_id=XYZ, where XYZ is a unique key. +// b) adds XlaClusterOutput nodes to represent outputs of the computation. +// These nodes are not marked with the _xla_compile_id attribute. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" + + namespace tensorflow { + +// Encapsulates nodes marked with the _xla_compile_id attribute into +// XlaLaunch operators. +class EncapsulateXlaComputationsPass : public GraphOptimizationPass { + public: + static const char* const kXlaClusterAttr; // _xla_compile_id + + Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for unit tests. + + // This pass has two stages: + // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes + // marked with the same _xla_compile_id attribute into functions. These + // functions contain the computations to be passed to XlaLaunch. During + // encapsulation, we sort the arguments into the order expected by + // XlaLaunch. + static Status Encapsulate(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // b) we rewrite the function calls generated in phase (a) into XlaLaunch + // operators. We also convert the XlaClusterOutput output nodes of the + // function call into the outputs of the XlaLaunch operator. + static Status BuildXlaLaunchOps(Graph* graph); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f643fb0cfe136caba42272d72f3972ec63a94bf3 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -0,0 +1,346 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h" +#include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +static std::unique_ptr MakeOuterGraph( + const FunctionLibraryDefinition& flib_def, const string& function) { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NodeDef def; + TF_CHECK_OK( + NodeDefBuilder("launch0", function, &flib_def) + .Input(a.node()->name(), 0, DT_INT32) + .Input(b.node()->name(), 0, DT_FLOAT) + .Input(c.node()->name(), 0, DT_INT32) + .Input(d.node()->name(), 0, DT_FLOAT) + .Input(u.node()->name(), 0, DT_RESOURCE) + .Input(v.node()->name(), 0, DT_RESOURCE) + .Input(w.node()->name(), 0, DT_RESOURCE) + .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") + .Attr("_variable_start_index", 4) + .Finalize(&def)); + + Status status; + Node* launch = scope.graph()->AddNode(def, &status); + TF_CHECK_OK(status); + TF_CHECK_OK(scope.DoShapeInference(launch)); + scope.graph()->AddEdge(a.node(), 0, launch, 0); + scope.graph()->AddEdge(b.node(), 0, launch, 1); + scope.graph()->AddEdge(c.node(), 0, launch, 2); + scope.graph()->AddEdge(d.node(), 0, launch, 3); + scope.graph()->AddEdge(u.node(), 0, launch, 4); + scope.graph()->AddEdge(v.node(), 0, launch, 5); + scope.graph()->AddEdge(w.node(), 0, launch, 6); + + auto out0 = + ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0)); + auto out1 = + ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1)); + auto out2 = + ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2)); + auto out3 = + ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3)); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +// Makes an encapsulate body graph for use in tests. +static std::unique_ptr MakeBodyGraph() { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1); + auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3); + + auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4); + auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5); + auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), arg0, arg2); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, arg3); + add_attrs(g.node()); + + auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"), + b_identity, 0); + auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1); + auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2); + auto out3 = + ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { + // Test that control edge insertion order doesn't affect the cache key + // (cluster name) generated by TPU encapsulate pass. + auto get_serialized_graph = [](bool control_input_reversed, + bool operand_reversed) -> string { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); + auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); + + ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1) + : ops::Add(scope.WithOpName("E"), a1, a0); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, + "launch0"); + }; + add_attrs(e.node()); + + TF_CHECK_OK(scope.ToGraph(graph.get())); + auto get_node_in_graph = [&graph](Node* node) { + return graph->FindNodeId(node->id()); + }; + // Insert control edge in different order. The order should not affect + // the encapsulated or serialized graph. + if (!control_input_reversed) { + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + } else { + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + } + } + TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + GraphDef gdef; + graph->ToGraphDef(&gdef); + // Before serialization, sort control inputs first to remove + // nondeterminism. + SortControlInputs(&gdef); + string serialized; + SerializeToStringDeterministic(gdef, &serialized); + return serialized; + }; + + // Changing the order of control input shouldn't affect the graph generated. + EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true, + /*operand_reversed=*/false), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); + + // Changing the order of data input should affect the graph generated. + EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/true), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); +} + +TEST(EncapsulateXlaComputations, Encapsulate) { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); + add_attrs(b_identity.node()); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), a, c); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, d); + add_attrs(g.node()); + + auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity); + auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e); + auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g); + auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + std::unique_ptr graph_copy(new Graph(&flib_def)); + CopyGraph(*graph, graph_copy.get()); + + TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + + std::unordered_map index = BuildNodeIndex(*graph); + string function = index.at("launch0")->type_string(); + + // Tests the outer graph is as expected. + { + std::unique_ptr outer = MakeOuterGraph(flib_def, function); + GraphDef expected_def; + outer->ToGraphDef(&expected_def); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def); + } + + // Tests the encapsulated body graph is as expected. + { + std::unique_ptr body = MakeBodyGraph(); + GraphDef expected_body_def; + body->ToGraphDef(&expected_body_def); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, + DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef); + } + + // Encapsulates the same computation again, verifies we reuse the same + // function. Encapsulation should be deterministic to avoid recompilation. + TF_ASSERT_OK( + EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); + std::unordered_map index_copy = BuildNodeIndex(*graph_copy); + string function_copy = index_copy.at("launch0")->type_string(); + EXPECT_EQ(function, function_copy); +} + +TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { + std::unique_ptr body_graph = MakeBodyGraph(); + FunctionDefLibrary flib; + TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function())); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + std::unique_ptr graph = MakeOuterGraph(flib_def, "launch0"); + TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get())); + + Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NameAttrList function; + function.set_name("launch0"); + auto launch = ops::XlaLaunch( + scope.WithOpName("launch0"), std::initializer_list{}, + std::initializer_list{a, b, c, d}, + std::initializer_list{u, v, w}, + DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); + + auto consumer0_a = + ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]); + auto consumer0_b = + ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]); + auto consumer0_c = + ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]); + auto consumer1 = + ops::Identity(scope.WithOpName("consumer1"), launch.results[1]); + auto consumer2 = + ops::Identity(scope.WithOpName("consumer2"), launch.results[2]); + auto consumer3 = + ops::Identity(scope.WithOpName("consumer3"), launch.results[3]); + + GraphDef expected_def; + TF_ASSERT_OK(scope.ToGraphDef(&expected_def)); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ(expected_def, actual_def); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 676f71a75aede2a7720ae0c8a579d64cc184509a..8212956adfeca263334e3d0d954f69e13c1ba28d 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -14,6 +14,7 @@ cc_library( hdrs = ["graphcycles.h"], deps = [ "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", ], ) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 805bbc62c1e2e877de87ab8faf3d60b829743df8..756377bd9502d7172b29f317c471963d1dee09a9 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -34,7 +34,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -44,7 +44,7 @@ namespace { typedef std::unordered_set NodeSet; template struct VecStruct { - typedef gtl::InlinedVector type; + typedef absl::InlinedVector type; }; template using Vec = typename VecStruct::type; diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 4d49a14b24d53bbcb434560d59b8c97a17e18f86..085c0e5adbb270e71ff3447a936555c99904e26c 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -13,25 +13,48 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { +// PRE_PLACEMENT passes: + +// EncapsulateXlaComputationsPass rewrites computations generated by the +// xla.compile() Python code into XlaLaunch nodes. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, + EncapsulateXlaComputationsPass); + +// from +// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc +// FunctionalizeControlFlowPass: 27 +// +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (XlaIf/XlaWhile). Following passes must +// handle those FunctionDef correctly. + +// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA: + REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + PartiallyDeclusterPass); + // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We // also need to run it after the graph been rewritten to have _Send nodes added // for fetches. Before the _Send nodes are added, fetch nodes are identified by // name, and encapsulation might remove that node from the graph. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, EncapsulateSubgraphsPass); // Must run after EncapsulateSubgraphsPass. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, - BuildXlaLaunchOpsPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, + BuildXlaOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 00a6f4075f9a18efc3895b033eb6d08e36088a53..0839f1cb3dafd9af533631c73a37a1df7172ac0b 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -7,15 +7,16 @@ package( ) cc_library( - name = "xla_launch_op", - srcs = ["xla_launch_op.cc"], - hdrs = ["xla_launch_op.h"], + name = "xla_ops", + srcs = ["xla_ops.cc"], + hdrs = ["xla_ops.h"], deps = [ "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", @@ -25,19 +26,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:variable_ops", - ], - alwayslink = 1, -) - -cc_library( - name = "parallel_check_op", - srcs = ["parallel_check_op.cc"], - visibility = ["//tensorflow/compiler/jit:friends"], - deps = [ - "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc deleted file mode 100644 index bd4eefbc0bb960f8ddc1d238057e73a29a098f26..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace { - -// Inputs 2*N tensors, outputs the first N inputs. -// Logs errors if input tensor i and i + N are not (near) identical -// in any position. -class ParallelCheckOp : public OpKernel { - public: - explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - template - int CompareTensors(DataType dtype, const char* v0, const char* v1, - int64 num_elts, int input_idx) { - int failed = 0; - const T* p0 = reinterpret_cast(v0); - const T* p1 = reinterpret_cast(v1); - double rtol; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(), - &rtol)) { - LOG(ERROR) << "can't convert parallel_check_rtol " - << flags->parallel_check_rtol << " to double"; - } - double atol; - if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(), - &atol)) { - LOG(ERROR) << "can't convert parallel_check_atol " - << flags->parallel_check_atol << " to double"; - } - for (int i = 0; i < num_elts; ++i) { - bool ok = (p0[i] == p1[i]); - VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i]; - if (!ok) { - if (std::is_same::value || std::is_same::value) { - float tolerance = - std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i]))); - T diff = p0[i] - p1[i]; - if (diff < 0) diff = 0 - diff; - ok = (diff <= tolerance); - } - if (ok) continue; - LOG(ERROR) << "Op " << name() << " fails equality at output " - << input_idx << " type " << DataTypeString(dtype) - << " element " << i << ": std_val=" << p0[i] - << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); - if (++failed > 10) break; - } - } - return failed; - } - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "Compute " << name(); - const int num_pairs = ctx->num_inputs() / 2; - for (int i = 0; i < num_pairs; ++i) { - CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); - Tensor t0 = ctx->input(i); - Tensor t1 = ctx->input(i + num_pairs); - int64 num_elts = t0.NumElements(); - CHECK_EQ(num_elts, t1.NumElements()); - - // Compare inputs elementwise for near-exact equality. - const char* v0 = t0.tensor_data().data(); - const char* v1 = t1.tensor_data().data(); - int failed = 0; - switch (ctx->input_dtype(i)) { - case DT_INT32: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_INT64: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_FLOAT: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_DOUBLE: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_BOOL: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - default: - LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); - } - if (failed > 0) { - LOG(ERROR) << "check failed for " << name() << " output " << i - << " num_elts: " << num_elts; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (flags->parallel_check_failfast) { - LOG(QFATAL) << "failfast on first parallel-check failure"; - } - } else { - VLOG(1) << "check passed for " << name() << " output " << i - << " num_elts: " << num_elts; - } - - // Propagate the std value. - if (IsRefType(ctx->input_dtype(i))) { - ctx->forward_ref_input_to_ref_output(i, i); - } else { - ctx->set_output(i, ctx->input(i)); - } - } - } - - TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp); -}; - -REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU), - ParallelCheckOp); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc deleted file mode 100644 index b313d48011b561eaab618692df49d1558c34a77c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ /dev/null @@ -1,284 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" - -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/xla_device.h" -#include "tensorflow/compiler/jit/xla_launch_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/variable_ops.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/util/stream_executor_util.h" - -namespace tensorflow { - -XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector& constants, - const std::vector& resources, - const NameAttrList& function) - : OpKernel(ctx), - constants_(constants), - resources_(resources), - device_type_(ctx->device_type()), - function_(function) { - if (device_type_ == DeviceType(DEVICE_CPU)) { - platform_id_ = se::host::kHostPlatformId; - } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id_ = ctx->device() - ->tensorflow_gpu_device_info() - ->stream->parent() - ->platform() - ->id(); - } else { - platform_id_ = nullptr; - } -} - -Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { - const XlaDevice::Metadata* metadata; - Status s = XlaDevice::GetMetadata(ctx, &metadata); - if (s.ok()) { - *cache = new XlaCompilationCache(metadata->client(), - metadata->jit_device_type()); - return Status::OK(); - } - - auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_); - if (!platform.ok()) { - return platform.status(); - } - xla::LocalClientOptions client_options; - client_options.set_platform(platform.ValueOrDie()); - client_options.set_intra_op_parallelism_threads( - ctx->device()->tensorflow_cpu_worker_threads()->num_threads); - auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); - if (!client.ok()) { - return client.status(); - } - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(), - ®istration)) { - return errors::InvalidArgument("No JIT device registered for ", - device_type_.type()); - } - *cache = new XlaCompilationCache( - client.ValueOrDie(), DeviceType(registration->compilation_device_name)); - return Status::OK(); -} - -void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOpBase::Compute " - << Canonicalize(function_.name(), AttrSlice(&function_.attr())); - // We store information about the JIT-compiled XLA computation - // in the ResourceMgr. - ResourceMgr* rm = ctx->resource_manager(); - OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); - - se::Stream* stream = - ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; - - XlaCompilationCache* cache; - OP_REQUIRES_OK(ctx, rm->LookupOrCreate( - rm->default_container(), "xla_cache", &cache, - [this, ctx](XlaCompilationCache** cache) { - return BuildCompilationCache(ctx, cache); - })); - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) - core::ScopedUnref cache_ref(cache); - - const XlaDevice::Metadata* metadata = nullptr; - Status s = XlaDevice::GetMetadata(ctx, &metadata); - bool allocate_xla_tensors = s.ok(); - bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams(); - - // Get the platform_id_ for XLA_* devices. - if (platform_id_ == nullptr) { - if (s.ok()) { - platform_id_ = metadata->platform()->id(); - } - } - - std::map variables = - SnapshotResourceVariables(ctx, resources_); - - xla::LocalClient* client = static_cast(cache->client()); - - XlaAllocator local_xla_allocator(client->backend().platform(), - ctx->device()->GetAllocator({})); - xla::DeviceMemoryAllocator* xla_allocator; - // If we are on an XlaDevice, use the underlying XLA platform's allocator - // directly. We could use the StreamExecutor's allocator which may - // theoretically be more correct, but XLA returns a nice OOM message in a - // Status and StreamExecutor does not. - // - // Importantly we can't use ctx->device()->GetAllocator() as the allocator - // (which local_xla_allocator above uses) as on an XlaDevice, this is a - // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a - // real allocator to allocate real buffers. - if (allocate_xla_tensors) { - xla_allocator = client->backend().memory_allocator(); - } else { - xla_allocator = &local_xla_allocator; - } - - XlaCompiler::Options options; - options.client = client; - if (ctx->op_device_context() != nullptr) { - options.device_ordinal = - ctx->op_device_context()->stream()->parent()->device_ordinal(); - } - options.device_type = cache->device_type(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.graph_def_version = ctx->function_library()->graph_def_version(); - options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); - options.device_allocator = xla_allocator; - if (metadata) { - options.shape_representation_fn = metadata->shape_representation_fn(); - } - - const XlaCompiler::CompilationResult* kernel; - xla::LocalExecutable* executable; - - std::map constant_args; - for (int i : constants_) { - constant_args.insert({i, ctx->input(i)}); - } - XlaCompiler::CompileOptions compile_options; - compile_options.is_entry_computation = true; - // Optimization: don't resolve constants. If we resolve constants we never - // emit them on the device, meaning that if they are needed by a following - // computation the host has to transfer them. - compile_options.resolve_compile_time_constants = false; - // Optimization: where possible, have the computation return a naked array - // rather than a one-element tuple. - compile_options.always_return_tuple = false; - - OP_REQUIRES_OK( - ctx, cache->Compile(options, function_, constant_args, variables, ctx, - &kernel, &executable, &compile_options)); - - VLOG(1) << "Executing XLA Computation..."; - - XlaComputationLaunchContext launch_context( - client, xla_allocator, allocate_xla_tensors, use_multiple_streams); - launch_context.PopulateInputs(ctx, kernel, variables); - - // Execute the computation. - VLOG(2) << "Executing computation."; - xla::ExecutableRunOptions run_options; - run_options.set_stream(stream); - run_options.set_allocator(xla_allocator); - run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - run_options.set_rng_seed(ctx->step_id()); - Env* env = Env::Default(); - auto start_time = env->NowMicros(); - - auto run_result = executable->Run(launch_context.arguments(), run_options); - OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - - auto elapsed = env->NowMicros() - start_time; - VLOG(2) << "Elapsed time: " << elapsed << "us"; - - launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie()); - VLOG(1) << "Done"; -} - -namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - -// Helper static functions to construct parameters for -// XlaLocalLaunchBase constructor from OpKernelConstruction. -std::vector ConstantsVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Tconstants", &constant_types)); - std::vector constants(constant_types.size()); - std::iota(constants.begin(), constants.end(), 0); - return constants; -} - -std::vector ResourcesVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Tconstants", &constant_types)); - - DataTypeVector arg_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Targs", &arg_types)); - - int num_resources; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Nresources", &num_resources)); - - std::vector resources(num_resources); - std::iota(resources.begin(), resources.end(), - constant_types.size() + arg_types.size()); - return resources; -} - -NameAttrList FunctionAttr(OpKernelConstruction* ctx) { - const NameAttrList* func; - OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); - return *func; -} - -#undef OP_REQUIRES_OK_RETURN -} // namespace - -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), - FunctionAttr(ctx)) {} - -XlaLocalLaunchOp::~XlaLocalLaunchOp() { - VLOG(1) << "XlaLocalLaunchOp destroyed"; -} - -REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); - -REGISTER_KERNEL_BUILDER(Name("XlaLaunch") - .Device(DEVICE_GPU) - .HostMemory("constants") - .HostMemory("resources"), - XlaLocalLaunchOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h deleted file mode 100644 index 8dfc4b382d51151b6383fe7dd75429f3124d39be..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_ -#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_ - -#include "tensorflow/compiler/jit/xla_compilation_cache.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/util/stream_executor_util.h" - -namespace tensorflow { - -// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. -// The only difference is that it does not require arguments to follow -// the "constants, then regular args, then resources" order. -// It takes vectors of constant and resource arguments explicitly. -// It does not have corresponding OpDef because it is never present -// in the GraphDef. -// Currently, it is used by eager runtime. FunctionLibraryRuntime creates -// this kernel when asked to create a kernel for an XLA-compiled function. -class XlaLocalLaunchBase : public OpKernel { - public: - XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector& constants, - const std::vector& resources, - const NameAttrList& function); - XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; - XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; - ~XlaLocalLaunchBase() override = default; - - void Compute(OpKernelContext* ctx) override; - - protected: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache); - - // Indexes of compile-time constant inputs - std::vector constants_; - // Indexes of resource inputs - std::vector resources_; - - DeviceType device_type_; - NameAttrList function_; - se::Platform::Id platform_id_; -}; - -// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph -// which will be compiled and executed using XLA. The XlaLocalLaunchOp is -// responsible for handling interactions with the TensorFlow executor. -// Once all inputs are present, and their shapes are known, the op can -// use a 'XlaCompilationCache' to compile and execute code which is specific -// to the shapes of input Tensors. -// XlaLocalLaunchOp uses xla::LocalClient::Compile() and -// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device -// memory. -class XlaLocalLaunchOp : public XlaLocalLaunchBase { - public: - explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); - ~XlaLocalLaunchOp() override; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..a85006eb0378688dffd634c13a392b02e379c7f2 --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -0,0 +1,499 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/kernels/xla_ops.h" + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + +namespace { + +Status PlatformInfoFromContext(OpKernelConstruction* ctx, + XlaPlatformInfo* result) { + DeviceType device_type = ctx->device_type(); + se::Platform::Id platform_id = nullptr; + const XlaDevice::Metadata* xla_device_metadata = nullptr; + std::unique_ptr xla_allocator; + xla::DeviceMemoryAllocator* device_allocator = nullptr; + + if (ctx->device_type() == DeviceType(DEVICE_CPU)) { + platform_id = se::host::kHostPlatformId; + } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) { + platform_id = ctx->device() + ->tensorflow_gpu_device_info() + ->stream->parent() + ->platform() + ->id(); + } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) { + // If we are on an XlaDevice, use the underlying XLA platform's allocator + // directly. We could use the StreamExecutor's allocator which may + // theoretically be more correct, but XLA returns a nice OOM message in a + // Status and StreamExecutor does not. + // + // Importantly we can't use ctx->device()->GetAllocator() as the allocator + // (which xla_allocator above uses) as on an XlaDevice, this is a dummy + // allocator that returns XlaTensor objects. The XlaCompiler needs a real + // allocator to allocate real buffers. + + platform_id = xla_device_metadata->platform()->id(); + device_allocator = + xla_device_metadata->client()->backend().memory_allocator(); + } + + if (!device_allocator) { + TF_ASSIGN_OR_RETURN(se::Platform* const platform, + se::MultiPlatformManager::PlatformWithId(platform_id)); + xla_allocator = absl::make_unique( + platform, ctx->device()->GetAllocator({})); + } + + *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + std::move(xla_allocator), device_allocator); + + return Status::OK(); +} + +// A closure describing how to run a compiled version of a TensorFlow function. +// +// It may seem unusual to stick the resource variable snapshots in this class. +// This is necessary: we need to use the snapshots observed by the compiler as +// the initial values for the resource variables (and cannot snapshot them again +// during execution) because otherwise we risk observing a different snapshot +// with shapes different from what we compiled for. +class XlaExecutableClosure { + public: + explicit XlaExecutableClosure( + xla::LocalClient* client, xla::LocalExecutable* executable, + const XlaCompiler::CompilationResult* compilation_result, + std::map resource_var_snapshots, + int num_constant_args) + : client_(client), + executable_(executable), + compilation_result_(compilation_result), + resource_var_snapshots_(std::move(resource_var_snapshots)), + num_constant_args_(num_constant_args) {} + + XlaExecutableClosure(XlaExecutableClosure&&) = default; + XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default; + + xla::LocalClient* client() const { return client_; } + xla::LocalExecutable* executable() const { return executable_; } + const XlaCompiler::CompilationResult* compilation_result() const { + return compilation_result_; + } + const std::map& resource_var_snapshots() const { + return resource_var_snapshots_; + } + int num_constant_args() const { return num_constant_args_; } + + private: + xla::LocalClient* client_; + xla::LocalExecutable* executable_; + const XlaCompiler::CompilationResult* compilation_result_; + std::map resource_var_snapshots_; + int num_constant_args_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure); +}; + +// This maintains a mapping from a globally unique ID to XlaExecutableClosure +// instances. +class XlaExecutableClosureStore { + public: + XlaExecutableClosureStore() : key_counter_(0) {} + + using KeyT = string; + + KeyT Produce(XlaExecutableClosure result) { + mutex_lock l(mutex_); + KeyT key = absl::StrCat(key_counter_++); + bool insert_successful = closures_.emplace(key, std::move(result)).second; + DCHECK(insert_successful); + (void)insert_successful; + return key; + } + + XlaExecutableClosure Consume(const KeyT& key) { + mutex_lock l(mutex_); + auto it = closures_.find(key); + DCHECK(it != closures_.end()); + XlaExecutableClosure value = std::move(it->second); + closures_.erase(it); + return value; + } + + static XlaExecutableClosureStore* Global() { + static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore; + return instance; + } + + private: + mutex mutex_; + int64 key_counter_ GUARDED_BY(mutex_); + gtl::FlatMap closures_ GUARDED_BY(mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); +}; + +} // namespace + +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function) + : OpKernel(ctx), + constants_(constants), + resources_(resources), + function_(function) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +static Status BuildCompilationCache(OpKernelContext* ctx, + const XlaPlatformInfo& platform_info, + XlaCompilationCache** cache) { + if (platform_info.xla_device_metadata()) { + *cache = new XlaCompilationCache( + platform_info.xla_device_metadata()->client(), + platform_info.xla_device_metadata()->jit_device_type()); + return Status::OK(); + } + + auto platform = + se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); + if (!platform.ok()) { + return platform.status(); + } + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); + if (!client.ok()) { + return client.status(); + } + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(), + ®istration)) { + return errors::InvalidArgument("No JIT device registered for ", + platform_info.device_type().type()); + } + *cache = new XlaCompilationCache( + client.ValueOrDie(), DeviceType(registration->compilation_device_name)); + return Status::OK(); +} + +static Status CompileToLocalExecutable( + OpKernelContext* ctx, const NameAttrList& function, + const XlaPlatformInfo& platform_info, absl::Span resources, + absl::Span constants, xla::LocalClient** client, + std::map* variables, + const XlaCompiler::CompilationResult** kernel, + xla::LocalExecutable** executable) { + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + if (!rm) { + return errors::Internal("No resource manager."); + } + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache) { + return BuildCompilationCache(ctx, platform_info, cache); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + + *variables = SnapshotResourceVariables(ctx, resources); + *client = static_cast(cache->client()); + + XlaCompiler::Options options; + options.client = *client; + if (ctx->op_device_context() != nullptr) { + options.device_ordinal = + ctx->op_device_context()->stream()->parent()->device_ordinal(); + } + options.device_type = cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = + (platform_info.platform_id() == se::host::kHostPlatformId); + options.device_allocator = platform_info.allocator(); + if (platform_info.xla_device_metadata()) { + options.shape_representation_fn = + platform_info.xla_device_metadata()->shape_representation_fn(); + } + + std::map constant_args; + for (int i : constants) { + constant_args.insert({i, ctx->input(i)}); + } + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + // If we resolve constants we never emit them on the device, meaning that if + // they are needed by a following computation the host has to transfer + // them. Not resolving constants is expected to be faster than resolving + // constants. + compile_options.resolve_compile_time_constants = true; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; + + return cache->Compile(options, function, constant_args, *variables, ctx, + kernel, executable, compile_options); +} + +void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOpBase::Compute " + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); + + xla::LocalClient* client; + const XlaCompiler::CompilationResult* kernel; + xla::LocalExecutable* executable; + std::map variables; + + OP_REQUIRES_OK( + ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, + constants_, &client, &variables, &kernel, + &executable)); + + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + + VLOG(1) << "Executing XLA Computation..."; + + XlaComputationLaunchContext launch_context( + client, platform_info_.allocator(), + /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), + platform_info_.UseMultipleStreams()); + launch_context.PopulateInputs(ctx, kernel, variables, + /*missing_ctx_input_prefix=*/0); + + // Execute the computation. + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(platform_info_.allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(GetXLARandomSeed()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + + auto run_result = executable->Run(launch_context.arguments(), run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time: " << elapsed << "us"; + + OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( + ctx, kernel, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); + VLOG(1) << "Done"; +} + +namespace { + +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + +// Helper static functions to construct parameters for +// XlaLocalLaunchBase constructor from OpKernelConstruction. +std::vector ConstantsVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + std::vector constants(constant_types.size()); + std::iota(constants.begin(), constants.end(), 0); + return constants; +} + +std::vector ResourcesVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + + DataTypeVector arg_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Targs", &arg_types)); + + int num_resources; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Nresources", &num_resources)); + + std::vector resources(num_resources); + std::iota(resources.begin(), resources.end(), + constant_types.size() + arg_types.size()); + return resources; +} + +NameAttrList FunctionAttr(OpKernelConstruction* ctx) { + const NameAttrList* func; + OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); + return *func; +} + +#undef OP_REQUIRES_OK_RETURN +} // namespace + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), + FunctionAttr(ctx)) {} + +XlaLocalLaunchOp::~XlaLocalLaunchOp() { + VLOG(1) << "XlaLocalLaunchOp destroyed"; +} + +XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) + : OpKernel(ctx), + constants_(ConstantsVector(ctx)), + resources_(ResourcesVector(ctx)), + function_(FunctionAttr(ctx)) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +void XlaCompileOp::Compute(OpKernelContext* ctx) { + xla::LocalClient* client; + const XlaCompiler::CompilationResult* kernel; + xla::LocalExecutable* executable; + std::map variables; + + OP_REQUIRES_OK( + ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, + constants_, &client, &variables, &kernel, + &executable)); + + // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even + // if it didn't have to compile the cluster because of a compilation-cache + // hit. This is because we at least need new snapshots of the resource + // variables. + XlaExecutableClosureStore::KeyT key = + XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( + client, executable, kernel, std::move(variables), constants_.size())); + + Allocator* cpu_allocator = [&] { + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_gpu_compatible(true); + host_alloc_attrs.set_on_host(true); + return ctx->device()->GetAllocator(host_alloc_attrs); + }(); + + Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); + compilation_key.flat()(0) = key; + + Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); + compilation_successful.flat()(0) = true; + + ctx->set_output(0, compilation_key); + ctx->set_output(1, compilation_successful); +} + +XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +void XlaRunOp::Compute(OpKernelContext* ctx) { + Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); + const XlaExecutableClosureStore::KeyT& key = key_tensor.flat()(0); + + XlaExecutableClosure closure = + XlaExecutableClosureStore::Global()->Consume(key); + + XlaComputationLaunchContext launch_context( + closure.client(), platform_info_.allocator(), + /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), + /*use_multiple_streams=*/platform_info_.UseMultipleStreams()); + + // We're missing the must-be-constant inputs, tell `PopulateInputs` + // about this. We don't actually need these inputs because they've + // already been baked into the compiled kernel. + launch_context.PopulateInputs( + ctx, closure.compilation_result(), closure.resource_var_snapshots(), + /*missing_ctx_input_prefix=*/closure.num_constant_args()); + + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(platform_info_.allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(GetXLARandomSeed()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + + auto run_result = + closure.executable()->Run(launch_context.arguments(), run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time in computation: " << elapsed << "us"; + + OP_REQUIRES_OK( + ctx, + launch_context.PopulateOutputs( + ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/closure.num_constant_args())); +} + +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); + +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") + .Device(DEVICE_GPU) + .HostMemory("constants") + .HostMemory("resources"), + XlaLocalLaunchOp); + +REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp); +REGISTER_KERNEL_BUILDER(Name("_XlaCompile") + .Device(DEVICE_GPU) + .HostMemory("constants") + .HostMemory("resources"), + XlaCompileOp); + +REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp); +REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..489d26eb30a66646158f39ea3fc6f55759c7f88e --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -0,0 +1,168 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + +// Holds some information about the platform on which an +// XlaLaunch/_XlaCompile/_XlaRun op must run on. +class XlaPlatformInfo { + public: + XlaPlatformInfo() : device_type_("") {} + explicit XlaPlatformInfo(const DeviceType device_type, + se::Platform::Id platform_id, + const XlaDevice::Metadata* xla_device_metadata, + std::unique_ptr xla_allocator, + xla::DeviceMemoryAllocator* device_allocator) + : device_type_(device_type), + platform_id_(platform_id), + xla_device_metadata_(xla_device_metadata), + xla_allocator_(std::move(xla_allocator)), + device_allocator_(device_allocator) { + CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr)); + } + + XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; + + bool UseMultipleStreams() const { + return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); + } + + xla::DeviceMemoryAllocator* allocator() const { + return device_allocator_ ? device_allocator_ : xla_allocator_.get(); + } + DeviceType device_type() const { return device_type_; } + + // This is equal to xla_device_metadata()->platform()->id() if + // xla_device_metadata() is not nullptr. + se::Platform::Id platform_id() const { return platform_id_; } + + // This may be null if the op this XlaPlatformInfo is for was not placed on an + // XLA device. + const XlaDevice::Metadata* xla_device_metadata() const { + return xla_device_metadata_; + } + bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } + + private: + DeviceType device_type_; + se::Platform::Id platform_id_; + + // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the + // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the + // XlaLaunch/_XlaCompile/_XlaRun OpKernel. + const XlaDevice::Metadata* xla_device_metadata_; + + // If the op associated with this XlaPlatformInfo is placed on an XLA device + // then device_allocator_ is the xla::Backend's memory allocator and + // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device + // then device_allocator_ is null and xla_allocator_ points to an appropriate + // XlaAllocator instance. + std::unique_ptr xla_allocator_; + xla::DeviceMemoryAllocator* device_allocator_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); +}; + +// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. +// The only difference is that it does not require arguments to follow +// the "constants, then regular args, then resources" order. +// It takes vectors of constant and resource arguments explicitly. +// It does not have corresponding OpDef because it is never present +// in the GraphDef. +// Currently, it is used by eager runtime. FunctionLibraryRuntime creates +// this kernel when asked to create a kernel for an XLA-compiled function. +class XlaLocalLaunchBase : public OpKernel { + public: + XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function); + XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; + XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; + ~XlaLocalLaunchBase() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + NameAttrList function_; + XlaPlatformInfo platform_info_; +}; + +// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph +// which will be compiled and executed using XLA. The XlaLocalLaunchOp is +// responsible for handling interactions with the TensorFlow executor. +// Once all inputs are present, and their shapes are known, the op can +// use a 'XlaCompilationCache' to compile and execute code which is specific +// to the shapes of input Tensors. +// XlaLocalLaunchOp uses xla::LocalClient::Compile() and +// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device +// memory. +class XlaLocalLaunchOp : public XlaLocalLaunchBase { + public: + explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); + ~XlaLocalLaunchOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); +}; + +class XlaCompileOp : public OpKernel { + public: + explicit XlaCompileOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + NameAttrList function_; + + XlaPlatformInfo platform_info_; +}; + +class XlaRunOp : public OpKernel { + public: + explicit XlaRunOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + XlaPlatformInfo platform_info_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 5b6692f523658749f7ef48f9d7d89e97d4ce8b09..07c5b2318851ed506711b9ee00c66fe680a3afd8 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -28,18 +28,6 @@ cc_library( ], ) -cc_library( - name = "parallel_check_op_flags", - srcs = ["parallel_check_op_flags.cc"], - hdrs = ["parallel_check_op_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "xla_device_flags", srcs = ["xla_device_flags.cc"], diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc deleted file mode 100644 index a61694b49407b923b7c83f35e903ef49a2175f0e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's parallel_check_op module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static ParallelCheckOpFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new ParallelCheckOpFlags; - flags->parallel_check_failfast = true; - flags->parallel_check_atol = "1e-5"; - flags->parallel_check_rtol = "1e-5"; - flag_list = new std::vector({ - Flag("parallel_check_failfast", &flags->parallel_check_failfast, - "Fail immediately on first parallel-check comparison error."), - Flag("parallel_check_atol", &flags->parallel_check_atol, - "Absolute error tolerance for parallel-check comparison."), - Flag("parallel_check_rtol", &flags->parallel_check_rtol, - "Relative error tolerance for parallel-check comparison."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// parallel_check_op module. -void AppendParallelCheckOpFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ParallelCheckOpFlags* GetParallelCheckOpFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h deleted file mode 100644 index 156a2a2a71097631e24d154b102cd9b85a990b3a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ - -// Legacy flags for the XLA bridge's parallel_check_op module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// parallel_check_op module. -void AppendParallelCheckOpFlags(std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// parallel_check_op module. -typedef struct { - bool parallel_check_failfast; // Fail immediately on first parallel-check - // comparison error. - string parallel_check_atol; // Absolute error tolerance for parallel-check - // comparison. - string parallel_check_rtol; // Relative error tolerance for parallel-check - // comparison. -} ParallelCheckOpFlags; - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ParallelCheckOpFlags* GetParallelCheckOpFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 45d422943c23f59823e6bfbcb355d4b58a6a225e..133d9823609efa688d8a8f7a066ccbfefc75c15b 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -27,7 +27,9 @@ limitations under the License. #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" @@ -39,7 +41,9 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" namespace tensorflow { @@ -65,57 +69,83 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave // such nodes out of XLA clusters. if (HasForwardedRefInput(node)) { + VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast."; return false; } return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } +bool HasResourceOutput(const Node& node) { + return std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); +} + +bool HasResourceInput(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end(); +} + +// Returns true if `node` is a resource operation recognized by tf2xla that +// operates on something other than resource variables. +bool IsNonResourceVarResourceOp(const Node& node) { + // TODO(b/112837194): We can't cluster these because we only support + // snapshotting resource variables (and we can't e.g. snapshot stacks). This + // limitation may be fixable with some work. + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string()); + return op_info && op_info->resource_kind() != XlaResourceKind::kVariable; +} + // Make sure we don't recurse infinitely on recursive functions. const int kMaxRecursionDepth = 10; bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime); // Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. bool IsCompilableWhile(const Node& while_node, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Loop marking: " << while_node.type_string(); - const NameAttrList* name_attr; NodeDef call; Status status; status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'cond' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'cond' attribute on While node."; return false; } const string cond_func = name_attr->name(); call.set_name("while_cond"); call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop condition: " << cond_func; + if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + lib_runtime)) { + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop condition: " << cond_func; return false; } status = GetNodeAttr(while_node.attrs(), "body", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'body' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'body' attribute on While node."; return false; } const string body_func = name_attr->name(); call.set_name("while_body"); call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop body: " << body_func; + if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + lib_runtime)) { + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop body: " << body_func; return false; } - VLOG(2) << "Loop is compilable."; return true; } @@ -123,12 +153,12 @@ bool IsCompilableWhile(const Node& while_node, // Every operator in the function must be compilable for a function to be // compilable. bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Function marking: " << call_def.op(); - if (depth > kMaxRecursionDepth) { - VLOG(2) << "Function depth limit exceeded"; + VLOG(2) << "Rejecting " << call_def.op() + << ": function depth limit exceeded."; return false; } @@ -136,9 +166,14 @@ bool IsCompilableCall(const NodeDef& call_def, Status status = lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); if (!status.ok()) { - VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; + VLOG(2) << "Rejecting " << call_def.op() + << ": could not instantiate: " << status; return false; } + + auto release_handle_on_return = gtl::MakeCleanup( + [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); }); + const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); CHECK(fbody); const FunctionDef& fdef = fbody->fdef; @@ -150,7 +185,8 @@ bool IsCompilableCall(const NodeDef& call_def, // tf2xla to translate the TF graph into XLA. So we avoid this for now. // // TODO(b/36139787): Create a mechanism to set inlining hints. - VLOG(2) << "Can't compile noinline function: " << fdef.DebugString(); + VLOG(2) << "Rejecting " << call_def.op() + << ": can't compile noinline function."; return false; } @@ -158,29 +194,25 @@ bool IsCompilableCall(const NodeDef& call_def, if (node->type_string() == "_Arg" || node->type_string() == "_Retval") continue; if (node->type_string() == "While") { - // Handle functional While loop (not in open source build). - return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime); + // Handle functional While loop. + return IsCompilableWhile(*node, jit_device_type, allow_resource_ops, + depth + 1, lib_runtime); + } + if (!allow_resource_ops && + (HasResourceInput(*node) || HasResourceOutput(*node))) { + return false; } if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, depth + 1, - lib_runtime)) { - VLOG(2) << "Function marking failed: unsupported op " << node->name() - << ": " << node->def().ShortDebugString(); + !IsCompilableCall(node->def(), jit_device_type, allow_resource_ops, + depth + 1, lib_runtime)) { + VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " + << node->name() << ": " << node->def().ShortDebugString(); return false; } } - VLOG(2) << "Function is compilable: " << call_def.op(); return true; } -// Tests whether `node` has a DT_RESOURCE typed input or output. -bool HasResourceInputOrOutput(const Node& node) { - return std::find(node.input_types().begin(), node.input_types().end(), - DT_RESOURCE) != node.input_types().end() || - std::find(node.output_types().begin(), node.output_types().end(), - DT_RESOURCE) != node.output_types().end(); -} - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // @@ -333,16 +365,23 @@ bool IsXlaFusable(const NodeDef& node) { return elementwise_ops->count(node.op()) > 0; } +// Nodes that XLA can compile are put in `candidates`. Nodes put in +// `isolated_nodes` must either be unclustered or be put in trivial single-node +// clusters. Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, - OrderedNodeSet* candidates) { + OrderedNodeSet* candidates, gtl::FlatSet* isolated_nodes) { OptimizerOptions opts; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, flib_def, opts)); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + std::vector compile_time_const_nodes(graph.num_node_ids(), false); + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, + &compile_time_const_nodes)); int64& fuel = legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; @@ -357,24 +396,29 @@ Status FindCompilationCandidates( } std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); + if (fuel >= std::numeric_limits::max() / 2) { + // The assumption is that if fuel started out as INT64_MAX, it will forever + // stay greater than INT64_MAX / 2. + VLOG(2) << "Starting fuel: infinity"; + } else { + VLOG(2) << "Starting fuel: " << fuel; + } + for (Node* node : sorted_nodes) { - VLOG(2) << "Fuel: " << fuel; if (fuel <= 0) { - VLOG(2) + VLOG(1) << "Hit fuel limit; not marking any remaining ops as clusterable."; break; } - VLOG(2) << "FindCompilationCandidates(): Processing " - << node->DebugString(); - DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(node->assigned_device_name(), &device_type)); + VLOG(4) << "Device type for " << node->name() << ": " + << device_type.type_string(); if (is_compilable_fn && !is_compilable_fn(node, device_type)) { - VLOG(2) << "Compilation rejected node: not compilable " << node->name() - << ": " << node->type_string(); + // is_compilable_fn has already logged the reason if it returned false. continue; } @@ -383,33 +427,93 @@ Status FindCompilationCandidates( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) { - VLOG(2) << "Compilation rejected node: unsupported op " << node->name() - << ": " << node->type_string(); + !IsCompilableCall(node->def(), jit_device_type, + registration->compile_resource_ops, 0, lib_runtime)) { + VLOG(2) << "Rejecting " << node->name() << ": unsupported op " + << node->type_string(); continue; } if (!registration->compile_resource_ops && - HasResourceInputOrOutput(*node)) { - VLOG(2) << "Compilation rejected node: resource input/output " - << node->name() << ": " << node->type_string(); + (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { + // We don't have a way of returning values of type DT_RESOURCE from XLA + // computations so we avoid auto-clustering nodes producing DT_RESOURCE. + // XlaLaunchOp also cannot snapshot resources that are not resource + // variables so we avoid clustering resource operations that operate on + // non-resource variables. + VLOG(2) << "Rejecting: " << node->name() << ": resource output " + << node->type_string(); continue; } + if (compile_time_const_nodes[node->id()]) { + const OpDef* op_def; + TF_RETURN_IF_ERROR( + graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); + if (op_def->is_stateful()) { + // It is easiest to demonstrate the problem we're trying to solve with + // an example. Say we have this graph: + // + // shape = RandomUniformInt(); + // reshape = Reshape(input, shape) + // + // Both RandomUniformInt and Reshape are compilable by XLA so, absent + // any other reason, we will try to put both shape and reshape in the + // same cluster. However, since XLA only supports statically shaped + // values, it will expect to be able to constant fold `shape` to get a + // static shape for `reshape`. This is a problem because side-effecting + // ops like RandomUniformInt() cannot be constant folded. We fix this + // by putting `shape` and `reshape` in different clusters, which results + // in us recompiling `reshape`'s cluster for every new value of `shape`, + // making `reshape` statically sized within each compilation. We + // simplify the solution even further by disallowing operations like + // `shape` from being part of *any* non-trivial cluster. They're either + // not compiled by XLA altogether or, if assigned to an XLA_* device + // with "must compile" semantics, compiled into a trivial single-op + // cluster. This approach leaves some room for improvement, and we can + // consider implementing a more aggressive data-flow-analysis based + // solution in the future if needed. + // + // One ugly problem we have to contend with: certain sets of ops *have* + // to be in the same cluster because values flowing between them have + // types that can't be live-in or live-out of a cluster. These ops are: + // + // - TensorArray ops operating on the same TensorArray instance. + // - Stack ops operating on the same Stack instance. + // + // To work around this we avoid isolating these specific ops. Because + // of this concession it is unsound to auto-cluster them because then + // we'd create clusters we could not compile (because we can't constant + // fold, say, a TensorArrayRead or a StackPopV2). But we don't + // auto-cluster these operations today so we're good for now. + const XlaResourceOpInfo* op_info = + GetResourceOpInfoForOp(node->type_string()); + bool is_tensor_array_or_stack_op = + op_info && op_info->resource_kind() != XlaResourceKind::kVariable; + if (!is_tensor_array_or_stack_op) { + VLOG(2) << "Isolating " << node->name() + << ": must-be-constant stateful op"; + isolated_nodes->insert(node); + // Keep going and execute all the other checks. + } + } + } + // We don't auto-cluster functional control flow nodes containing resource + // operations because safety checks are trickier in this case. + // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not + // for CPU/GPU. if (node->type_string() == "While" && - !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { + !IsCompilableWhile(*node, jit_device_type, + registration->compile_resource_ops, 0, + lib_runtime)) { continue; } // _Arg nodes in a top-level function represent feeds. // Do not compile them. if (node->type_string() == "_Arg") { - VLOG(2) << "Skipping jit compilation for '_Arg'-typed node " - << node->DebugString(); continue; } // _Retval nodes in a top-level function represent fetches. // Do not compile them. if (node->type_string() == "_Retval") { - VLOG(2) << "Compilation rejected node: return value " << node->name() - << ": " << node->type_string(); continue; } candidates->insert(node); @@ -419,6 +523,31 @@ Status FindCompilationCandidates( return Status::OK(); } +// Determine the global jit level which is ON if either the +// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag +// is true. +OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( + const GraphOptimizationPassOptions& options) { + OptimizerOptions::GlobalJitLevel global_jit_level = + options.session_options->config.graph_options() + .optimizer_options() + .global_jit_level(); + if (global_jit_level == OptimizerOptions::DEFAULT) { + // To set compilation to be on by default, change the following line. + global_jit_level = OptimizerOptions::OFF; + } + legacy_flags::MarkForCompilationPassFlags* flags = + legacy_flags::GetMarkForCompilationPassFlags(); + if (flags->tf_xla_auto_jit == -1 || + (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { + // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides + // the setting in ConfigProto. + global_jit_level = + static_cast(flags->tf_xla_auto_jit); + } + return global_jit_level; +} + struct Cluster { // Identifies the node that represents this cluster in the cycle detection // graph. @@ -433,7 +562,11 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); - return IsCompilableCall(ndef, jit_device_type, 0, flr); + + // We can always *compile* resource operations, even if we are sometimes + // unable to auto-cluster them. + const bool compile_resource_ops = true; + return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr); } Status MarkForCompilationPass::Run( @@ -441,22 +574,9 @@ Status MarkForCompilationPass::Run( // TODO(phawkins): precompute the "GetCompilationDevice" properties of each // device ahead of time. OptimizerOptions::GlobalJitLevel global_jit_level = - options.session_options->config.graph_options() - .optimizer_options() - .global_jit_level(); - if (global_jit_level == OptimizerOptions::DEFAULT) { - // To set compilation to be on by default, change the following line. - global_jit_level = OptimizerOptions::OFF; - } + GetGlobalJitLevel(options); legacy_flags::MarkForCompilationPassFlags* flags = legacy_flags::GetMarkForCompilationPassFlags(); - if (flags->tf_xla_auto_jit == -1 || - (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { - // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides - // the setting in ConfigProto. - global_jit_level = - static_cast(flags->tf_xla_auto_jit); - } bool cpu_global_jit = flags->tf_xla_cpu_global_jit; bool fusion_only = flags->tf_xla_fusion_only; @@ -475,6 +595,7 @@ Status MarkForCompilationPass::Run( const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device."; return false; } @@ -484,21 +605,36 @@ Status MarkForCompilationPass::Run( // If there is a _XlaCompile annotation, use its value. bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") is false."; + } + return compile; + } status = fld->GetAttr(*node, kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") on callee is false."; + } + return compile; + } // If inputs to `node` can have conflicting deadness (i.e. some are alive // and some are dead) then don't compile it. XLA cannot represent the // deadness semantics of these nodes correctly and auto-clustering these // nodes can cause deadness to propagate to nodes that should be live. if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; return false; } // Check for fusable ops only if requested. if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { + VLOG(2) << "Rejecting " << node->name() + << ": not fusable op but fusion_only enabled."; return false; } @@ -506,12 +642,154 @@ Status MarkForCompilationPass::Run( // Ignore enable_jit_by_default if global jit compilation for CPU // is explicitly requested via tf_xla_cpu_global_jit flag bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; - return (ignore_registration || registration->enable_jit_by_default) && - global_jit_level > 0; + bool should_compile = + (ignore_registration || registration->enable_jit_by_default) && + global_jit_level != OptimizerOptions::OFF; + if (!should_compile) { + if (global_jit_level == OptimizerOptions::OFF) { + VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; + } else { + VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + } + } + return should_compile; }; return RunImpl(options, is_compilable); } +static string RatioToString(int numerator, int denominator) { + return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, + (100.0 * numerator) / denominator); +} + +static void VLogClusteringSummary(const Graph& g) { + if (!VLOG_IS_ON(2)) { + return; + } + + std::map cluster_name_to_size; + std::map> + cluster_name_to_op_histogram; + std::map unclustered_op_histogram; + int clustered_node_count = 0; + + for (Node* n : g.nodes()) { + absl::optional cluster_name = GetXlaClusterForNode(*n); + if (cluster_name) { + clustered_node_count++; + cluster_name_to_size[*cluster_name]++; + cluster_name_to_op_histogram[*cluster_name][n->type_string()]++; + } else { + unclustered_op_histogram[n->type_string()]++; + } + } + + int unclustered_node_count = g.num_nodes() - clustered_node_count; + + VLOG(2) << "*** Clustering info for graph of size " << g.num_nodes(); + VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size " + << RatioToString(clustered_node_count, g.num_nodes()); + + for (const auto& cluster_name_size_pair : cluster_name_to_size) { + absl::string_view cluster_name = cluster_name_size_pair.first; + int size = cluster_name_size_pair.second; + VLOG(2) << " " << cluster_name << " " + << RatioToString(size, g.num_nodes()); + for (const auto& op_count_pair : + cluster_name_to_op_histogram[cluster_name]) { + VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second + << " instances"; + } + } + + if (!unclustered_op_histogram.empty()) { + VLOG(2) << " Unclustered nodes: " + << RatioToString(unclustered_node_count, g.num_nodes()); + for (const auto& pair : unclustered_op_histogram) { + VLOG(3) << " " << pair.first << ": " << pair.second << " instances"; + } + } + + struct EdgeInfo { + absl::string_view node_name; + absl::optional cluster_name; + + absl::string_view GetClusterName() const { + return cluster_name ? *cluster_name : "[none]"; + } + + std::pair> AsPair() + const { + return {node_name, cluster_name}; + } + + bool operator<(const EdgeInfo& other) const { + return AsPair() < other.AsPair(); + } + }; + + using EdgeInfoMap = std::map>; + + EdgeInfoMap incoming_edge_infos; + EdgeInfoMap outgoing_edge_infos; + + std::set cluster_names_to_print; + + for (const Edge* e : g.edges()) { + const Node* from = e->src(); + absl::optional from_cluster_name = + GetXlaClusterForNode(*from); + + const Node* to = e->dst(); + absl::optional to_cluster_name = + GetXlaClusterForNode(*to); + + if (to_cluster_name == from_cluster_name) { + continue; + } + + if (to_cluster_name) { + incoming_edge_infos[*to_cluster_name] + [EdgeInfo{from->name(), from_cluster_name}]++; + cluster_names_to_print.insert(*to_cluster_name); + } + + if (from_cluster_name) { + outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++; + cluster_names_to_print.insert(*from_cluster_name); + } + } + + VLOG(2) << "*** Inter-Cluster edges:"; + if (cluster_names_to_print.empty()) { + VLOG(2) << " [none]"; + } + + auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name, + const EdgeInfoMap& edge_info_map, + absl::string_view desc) { + auto it = edge_info_map.find(cluster_name); + if (it != edge_info_map.end()) { + VLOG(2) << " " << it->second.size() << " " << desc << " edges"; + for (const auto& edge_info_count_pair : it->second) { + VLOG(2) << " " << edge_info_count_pair.first.GetClusterName() << " " + << edge_info_count_pair.first.node_name << " # " + << edge_info_count_pair.second; + } + } else { + VLOG(2) << " No " << desc << " edges."; + } + }; + + for (absl::string_view cluster_name : cluster_names_to_print) { + VLOG(2) << " ** Cluster " << cluster_name; + print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos, + "incoming"); + print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos, + "outgoing"); + } +} + // Is 'node' an operator that consumes only the shape of its input, not the // data itself? static bool IsShapeConsumerOp(const Node& node) { @@ -519,6 +797,43 @@ static bool IsShapeConsumerOp(const Node& node) { node.type_string() == "Size"; } +static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) { + // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then + // ignore it during resource operation safety analysis. We need this hack + // because of two reasons: + // + // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled. + // 2. We don't support live-out values of type DT_RESOURCE and live-in values + // of type DT_RESOURCE that are not resource variables. + // + // Together these imply we cannot let resource variable safety analysis + // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different + // clusters: both of them will have to be clustered because of (1) and we + // won't be able to keep the edge between the two as neither the input to the + // second XLA cluster nor the output from the first XLA cluster are supported + // because of (2). + // + // TODO(b/113100872): This can be fixed if the TensorFlow representation for + // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then + // (2) would no longer hold. + + if (n.assigned_device_name().empty()) { + *ignore = false; + return Status::OK(); + } + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n.assigned_device_name(), &device_type)); + + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + *ignore = true; + } else { + *ignore = registration->compile_resource_ops; + } + return Status::OK(); +} + // Sequence number generator to ensure clusters have unique names. static std::atomic cluster_sequence_num; @@ -534,11 +849,12 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); OrderedNodeSet compilation_candidates; + gtl::FlatSet isolated_nodes; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? options.session_options->env : Env::Default(), - is_compilable_fn, &compilation_candidates)); + is_compilable_fn, &compilation_candidates, &isolated_nodes)); if (compilation_candidates.empty()) { VLOG(2) << "No compilable candidates"; @@ -547,6 +863,8 @@ Status MarkForCompilationPass::RunImpl( GraphCycles cycles; TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( + graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -559,6 +877,8 @@ Status MarkForCompilationPass::RunImpl( worklist.push_back(&clusters[node->id()]); } + OptimizerOptions::GlobalJitLevel global_jit_level = + GetGlobalJitLevel(options); legacy_flags::MarkForCompilationPassFlags* flags = legacy_flags::GetMarkForCompilationPassFlags(); @@ -579,11 +899,16 @@ Status MarkForCompilationPass::RunImpl( "Found control flow node in clustering worklist: ", node_from->type_string()); } + + if (isolated_nodes.count(node_from)) { + continue; + } + string from_scope; string to_scope; for (int to : cycles.Successors(from)) { if (to >= graph->num_node_ids()) { - // Node is a "frame" node that is present only in the cycle detection + // Node is a fictitious node that is present only in the cycle detection // graph. No clustering is possible. continue; } @@ -596,15 +921,20 @@ Status MarkForCompilationPass::RunImpl( node_to->assigned_device_name()) { continue; } + if (isolated_nodes.count(node_to)) { + continue; + } // Look for an _XlaScope on both nodes. If both nodes have a // scope and the scopes do not match, do not cluster along this - // edge. If even one of the nodes lacks an _XlaScope attribute, + // edge. This restriction is overridden if the global_jit_level is ON. If + // even one of the nodes lacks an _XlaScope attribute, // then it is treated as a "bridge" and a cluster may be created // along it. We may want to restrict this behavior to require // all nodes marked with _XlaCompile=true to also have a // _XlaScope property set (and raise an error otherwise); but // for now we don't do this. - if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && + if (global_jit_level == OptimizerOptions::OFF && + GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() && from_scope != to_scope) { continue; @@ -652,6 +982,11 @@ Status MarkForCompilationPass::RunImpl( // Names for each cluster. std::unordered_map cluster_names; + if (flags->tf_xla_clustering_debug) { + dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph, + options.flib_def); + } + // Mark clusters for compilation that: // * are placed on a device that requires compilation (an XlaDevice), // * are explicitly marked for compilation (_XlaCompile=true), or @@ -689,7 +1024,7 @@ Status MarkForCompilationPass::RunImpl( string& name = cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; @@ -700,6 +1035,9 @@ Status MarkForCompilationPass::RunImpl( dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def); } + + VLogClusteringSummary(*graph); + return Status::OK(); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index e9acbfb19e42cb43cb0b986c438a569de29b2ebc..f1137af3c1e8539fda318d88d2c5b5187953ccab 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -40,20 +40,18 @@ class MarkForCompilationPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; - // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass - // unconditionally, call RunImpl() directly. - // is_compilable_fn, if set, is a predicate that must be true for a node to - // be compiled. + private: Status RunImpl(const GraphOptimizationPassOptions& options, const std::function& is_compilable_fn = {}); + + friend class MarkForCompilationPassTestHelper; }; // Returns true iff 'ndef' is a call to a function that is compilable. A // function is compilable iff every operator in the function body is // compilable. bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 2c5f4fb774fcab082c0d0d316cdc6757cacc1e96..4f9145b4799d5fbaaae2bafd47edec7fa6e463a3 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/memory/memory.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" @@ -26,11 +29,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -39,27 +42,6 @@ namespace { REGISTER_OP("UncompilableNullary").Output("o: float"); REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); -Status MarkForCompilation(std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def) { - // Assign all nodes to the CPU device. - static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; - for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); - } - - GraphOptimizationPassOptions opt_options; - opt_options.graph = graph; - opt_options.flib_def = flib_def; - MarkForCompilationPass pass; - return pass.RunImpl(opt_options); -} - -Status MarkForCompilation(std::unique_ptr* graph) { - FunctionDefLibrary flib; - FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); - return MarkForCompilation(graph, &flib_def); -} - std::unordered_map GetClusters(const Graph& graph) { std::unordered_map ids; for (Node* node : graph.nodes()) { @@ -69,9 +51,35 @@ std::unordered_map GetClusters(const Graph& graph) { ids[node->name()] = cluster; } } + + if (VLOG_IS_ON(2)) { + VLOG(2) << "Clusters:"; + for (const auto& p : ids) { + VLOG(2) << " " << p.first << " -> " << p.second; + } + } return ids; } +gtl::FlatMap> GetClusterSets( + const Graph& g, std::vector* cluster_names = nullptr) { + CHECK(cluster_names == nullptr || cluster_names->empty()); + gtl::FlatMap> cluster_sets; + for (const auto& p : GetClusters(g)) { + cluster_sets[p.second].push_back(p.first); + } + for (auto& p : cluster_sets) { + if (cluster_names != nullptr) { + cluster_names->push_back(p.first); + } + std::sort(p.second.begin(), p.second.end()); + } + if (cluster_names != nullptr) { + std::sort(cluster_names->begin(), cluster_names->end()); + } + return cluster_sets; +} + TEST(XlaCompilationTest, Chains) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; @@ -88,7 +96,7 @@ TEST(XlaCompilationTest, Chains) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(4, clusters.size()); EXPECT_EQ(clusters["B"], clusters["C"]); @@ -113,7 +121,7 @@ TEST(XlaCompilationTest, UncompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -133,7 +141,7 @@ TEST(XlaCompilationTest, CompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); @@ -156,7 +164,7 @@ TEST(XlaCompilationTest, Complex128Unsupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); } @@ -177,7 +185,7 @@ TEST(XlaCompilationTest, HalfSupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_FALSE(clusters.empty()); } @@ -206,7 +214,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); // Everything should be compiled. } @@ -220,7 +228,7 @@ TEST(XlaCompilationTest, FunctionCalls) { {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}}); FunctionDef noinline = compilable; noinline.mutable_signature()->set_name("NoInlineFn"); - AddAttr("_noinline", bool(true), noinline.mutable_attr()); + AddAttr("_noinline", static_cast(true), noinline.mutable_attr()); FunctionDefLibrary flib; *flib.add_function() = compilable; @@ -241,7 +249,8 @@ TEST(XlaCompilationTest, FunctionCalls) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def)); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -272,7 +281,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { ops::UnaryOp("Shape", d, builder.opts().WithName("E")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. } @@ -359,7 +368,7 @@ TEST(XlaCompilationTest, SymbolicGradients) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -384,7 +393,7 @@ TEST(XlaCompilationTest, Loops) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // Nothing should be compiled. In particular, 'd' and 'c' must not be @@ -392,6 +401,44 @@ TEST(XlaCompilationTest, Loops) { EXPECT_EQ(0, clusters.size()); } +TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor()) + .WithAttr(kXlaScopeAttr, "ScopeA")); + Node* b = ops::UnaryOp( + "Relu", a, + builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); + ops::BinaryOp( + "MatMul", a, b, + builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC")); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + FunctionDefLibrary flib; + FunctionLibraryDefinition flib_def(graph->op_registry(), flib); + SessionOptions session_options; + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_global_jit_level(OptimizerOptions::ON_2); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( + &graph, &flib_def, &session_options)); + auto clusters = GetClusters(*graph); + + // The computation is: C = A + relu(A) + // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC. + // In this case, the GlobalJitLevel overrides the scopes to cluster while + // ignoring scopes. + EXPECT_EQ(3, clusters.size()); + EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_EQ(clusters["A"], clusters["C"]); +} + TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; @@ -411,7 +458,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A + relu(A) @@ -442,7 +489,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: D = relu(A) + (A @ relu(A)) @@ -472,7 +519,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A @ relu(A) @@ -483,45 +530,111 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { EXPECT_EQ(clusters["B"], clusters["C"]); } -REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float"); -REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource"); - namespace { +Node* MakeRead(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output read = + ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); + return read.node(); +} -class DummyOp : public XlaOpKernel { - using XlaOpKernel::XlaOpKernel; - void Compile(XlaOpKernelContext* ctx) override {} -}; - -REGISTER_XLA_OP(Name("ResourceInput"), DummyOp); -REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp); +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id), + var_handle, value_to_write); + return assign_op.operation.node(); +} +Node* MakeNeutral(const Scope& scope, const string& id) { + return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); +} } // namespace -TEST(XlaCompilationTest, Resources) { +TEST(XlaCompilationTest, ResourcesClusteringAllowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, write); + + FixupSourceAndSinkEdges(root.graph()); std::unique_ptr graph(new Graph(OpRegistry::Global())); - GraphDef graphdef; - { - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* a = - ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); - Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); - // We should not form clusters with resource ops by default. - Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C")); - Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D")); - ops::UnaryOp("Relu", d, builder.opts().WithName("E")); - TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); - } - TF_ASSERT_OK(MarkForCompilation(&graph)); - auto clusters = GetClusters(*graph); - EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. + TF_EXPECT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + gtl::FlatMap> cluster_sets = + GetClusterSets(*graph); + ASSERT_EQ(cluster_sets.size(), 1); + std::vector expected_clustered_nodes = {"AssignmentW", "ReadR", + "ValueToAssignW"}; + ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); +} + +TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + + FixupSourceAndSinkEdges(root.graph()); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + gtl::FlatMap> cluster_sets = + GetClusterSets(*graph); + ASSERT_EQ(cluster_sets.size(), 1); + std::vector expected_clustered_nodes = {"AssignmentW", + "ValueToAssignW"}; + ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); +} + +TEST(XlaCompilationTest, ChainOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* neutral_0 = MakeNeutral(root, "N0"); + Node* read_0 = MakeRead(root, "R0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral_1 = MakeNeutral(root, "N1"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral_0); + root.graph()->AddControlEdge(neutral_0, read_0); + root.graph()->AddControlEdge(read_0, write_1); + root.graph()->AddControlEdge(write_1, neutral_1); + root.graph()->AddControlEdge(neutral_1, read_1); + + FixupSourceAndSinkEdges(root.graph()); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::vector cluster_names; + gtl::FlatMap> cluster_sets = + GetClusterSets(*graph, &cluster_names); + + ASSERT_EQ(cluster_sets.size(), 2); + + std::vector expected_clustered_nodes_a = {"AssignmentW0", "ConstN0", + "ValueToAssignW0"}; + ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); + + std::vector expected_clustered_nodes_b = { + "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"}; + ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b); } TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { std::unique_ptr graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); { - auto BuildNoopNode = [](StringPiece name, Graph* graph) { + auto BuildNoopNode = [](absl::string_view name, Graph* graph) { NodeDefBuilder builder(name, "NoOp"); NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); @@ -542,13 +655,13 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { TF_EXPECT_OK(root.ToGraph(graph.get())); - Status status = MarkForCompilation(&graph); + Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.ToString(), - "Edge from c to a would create a cycle.\n" - "+-> a\n" - "| b\n" - "+-- c\n")); + EXPECT_TRUE(absl::StrContains(status.ToString(), + "Edge from c to a would create a cycle.\n" + "+-> a\n" + "| b\n" + "+-- c\n")); } TEST(XlaCompilationTest, Retval) { @@ -570,7 +683,7 @@ TEST(XlaCompilationTest, Retval) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -588,7 +701,7 @@ TEST(XlaCompilationTest, DontCountIdentityOps) { auto r = ops::_Retval(root.WithOpName("R"), c, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -604,7 +717,7 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { auto r = ops::_Retval(root.WithOpName("R"), b, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -618,7 +731,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), 0.5f); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_EQ(1, GetClusters(*graph).size()); } @@ -629,7 +742,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), string("string")); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_TRUE(GetClusters(*graph).empty()); } } @@ -644,7 +757,7 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); @@ -667,7 +780,7 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); @@ -699,7 +812,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); @@ -713,5 +826,139 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { EXPECT_EQ(clusters, expected_clusters); } +TEST(XlaCompilationTest, RandomShape) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1}); + Output shape = + ops::RandomUniformInt(root.WithOpName("shape"), shape_shape, + ops::Const(root.WithOpName("minval"), 1), + ops::Const(root.WithOpName("maxval"), 20)); + Output reshape_input = + ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["shape"], ""); +} + +TEST(XlaCompilationTest, RandomShapeWithFunc) { + Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); + + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/"Stateful_func", /*in_def=*/{}, + /*out_def=*/{"out: int32"}, + /*attr_def*/ + {}, /*node_def=*/ + {FunctionDefHelper::Const("shape_shape", 2), + FunctionDefHelper::Const("minval", 1), + FunctionDefHelper::Const("maxval", 20), + {{"shape"}, + "RandomUniformInt", + {"shape_shape:output:0", "minval:output:0", "maxval:output:0"}, + {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}}, + /*ret_def=*/{{"out", "shape:output:0"}}); + + func.mutable_signature()->set_is_stateful(true); + *flib_def.add_function() = std::move(func); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + NodeDef call_node; + call_node.set_name("fn_call"); + call_node.set_op("Stateful_func"); + Status status; + Node* call = root.graph()->AddNode(call_node, &status); + TF_ASSERT_OK(status); + + Output shape = Output(call, 0); + Output reshape_input = + ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + auto fld = absl::make_unique(OpRegistry::Global(), + flib_def); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get())); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["fn_call"], ""); +} + +TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape_shape = + ops::Const(root.WithOpName("test/shape_shape"), {2}, {1}); + Output shape = + ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape, + ops::Const(root.WithOpName("test/minval"), 1), + ops::Const(root.WithOpName("test/maxval"), 20)); + Output reshape_input = + ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/shape_rng"], ""); + EXPECT_NE(clusters["test/reshape"], ""); + EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); +} + +TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + Scope root = Scope::NewRootScope().ExitOnError(); + ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1, + DT_INT32); + Output zero = ops::Const(root.WithOpName("test/zero"), 0); + ops::TensorArrayWrite tensor_array_write( + root.WithOpName("test/write"), tensor_array.handle, zero, + ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow); + Output tensor_array_read = + ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle, + zero, tensor_array_write.flow_out, DT_INT32); + Output reshape = + ops::Reshape(root.WithOpName("test/reshape"), + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT), + tensor_array_read); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/read"], ""); + EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..d56d0f8ccfcdab40003be38059228cb255921b64 --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + SessionOptions* session_options) { + // Assign all unassigned nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } + } + + // Call AddDevices to register the XLA devices. + // + // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to + // make this more direct, but probably not worth it solely for this test. + std::vector devices; + TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); + + auto delete_devices = gtl::MakeCleanup([&] { + for (Device* d : devices) { + delete d; + } + }); + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + opt_options.session_options = session_options; + opt_options.flib_def = flib_def; + MarkForCompilationPass pass; + return pass.RunImpl(opt_options); +} + +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + SessionOptions session_options; + return MarkForCompilation(graph, flib_def, &session_options); +} + +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph) { + FunctionDefLibrary flib; + FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); + return MarkForCompilation(graph, &flib_def); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..216baaf933dc1f7e694289eea5d23996b595f4d4 --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ +#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" + +namespace tensorflow { +class MarkForCompilationPassTestHelper { + public: + // Runs the MarkForCompilation pass on `graph` after assigning all nodes in + // `graph` to the CPU device. To make testing easier, ignores device + // registration, _XlaCompile attributes, input deadness and global jit level. + static Status MarkForCompilation(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + SessionOptions* session_options); + + // Like `MarkForCompilation` but creates a default SessionOptions. + static Status MarkForCompilation(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // Like `MarkForCompilation` but creates `flib_def` from the op registry. + static Status MarkForCompilation(std::unique_ptr* graph); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8ace628e6b76e011ecddd4d526efc4db9c9237e --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -0,0 +1,458 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/node_matchers.h" + +#include +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { +namespace testing { +namespace matchers { +namespace { + +using impl::NodeMatcherProperties; + +string IndentAllButFirstLine(absl::string_view text) { + std::vector lines = absl::StrSplit(text, '\n'); + for (int i = 1; i < lines.size(); i++) { + lines[i].insert(0, " "); + } + return absl::StrJoin(lines, "\n"); +} + +template +bool CompareTensor(const Tensor& actual, const Tensor& expected, + ::testing::MatchResultListener* listener) { + if (actual.NumElements() != expected.NumElements()) { + if (listener->IsInterested()) { + *listener << "\nwas looking for tensor with " << expected.NumElements() + << " elements, found tensor with " << actual.NumElements() + << " elements"; + return false; + } + } + + for (int64 i = 0, e = actual.NumElements(); i < e; i++) { + if (actual.flat()(i) != expected.flat()(i)) { + *listener << "\nmismatch in constant tensor at index " << i + << " expected = " << expected.flat()(i) + << " actual = " << actual.flat()(i); + return false; + } + } + + return true; +} + +bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, + ::testing::MatchResultListener* listener) { + if (tensor.dtype() != expected_tensor.dtype()) { + if (listener->IsInterested()) { + *listener << "\nexpected tensor of type " + << DataType_Name(expected_tensor.dtype()) + << " but found one of type " << DataType_Name(tensor.dtype()); + return false; + } + } + + switch (tensor.dtype()) { + case DT_FLOAT: + return CompareTensor(tensor, expected_tensor, listener); + case DT_DOUBLE: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT8: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT16: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT32: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT64: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT8: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT16: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT32: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT64: + return CompareTensor(tensor, expected_tensor, listener); + default: + LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly. + << DataType_Name(tensor.dtype()); + } +} + +using Input = std::pair; + +struct NodeMatcher : public ::testing::MatcherInterface { + bool MatchAndExplain( + const Node* node, + ::testing::MatchResultListener* listener) const override { + if (op && node->type_string() != *op) { + if (listener->IsInterested()) { + *listener << "\nexpected op " << *op << " but found " + << node->type_string(); + } + return false; + } + + if (assigned_device && node->assigned_device_name() != *assigned_device) { + if (listener->IsInterested()) { + *listener << "\nexpected assigned_device " << *assigned_device + << " but found \"" << node->assigned_device_name() << "\""; + } + return false; + } + + if (name && node->name() != *name) { + if (listener->IsInterested()) { + *listener << "\nexpected name " << *name << " but found " + << node->name(); + } + return false; + } + + if (constant_value) { + const TensorProto* proto = nullptr; + if (!GetNodeAttr(node->def(), "value", &proto).ok()) { + if (listener->IsInterested()) { + *listener << "\ncould not find \"value\" attribute in node"; + } + return false; + } + + Tensor tensor(proto->dtype()); + if (!tensor.FromProto(*proto)) { + if (listener->IsInterested()) { + *listener << "\ncould not convert TensorProto in \"value\" attribute " + "to Tensor"; + } + return false; + } + + if (!MatchAndExplainTensor(/*tensor=*/tensor, + /*expected_tensor=*/*constant_value, + listener)) { + return false; + } + } + + if (input_matchers) { + if (input_matchers->size() != node->num_inputs()) { + if (listener->IsInterested()) { + *listener << "\nexpected " << input_matchers->size() + << " inputs but node has " << node->num_inputs(); + } + return false; + } + + for (int input_idx = 0, e = input_matchers->size(); input_idx < e; + input_idx++) { + if (!MatchAndExplainInput(node, input_idx, listener)) { + return false; + } + } + } + + std::vector control_deps; + for (const Edge* e : node->in_edges()) { + if (e->IsControlEdge()) { + control_deps.push_back(e->src()); + } + } + + ::testing::StringMatchResultListener inner_listener; + if (control_dep_set && + !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) { + if (listener->IsInterested()) { + string explanation = inner_listener.str(); + if (!explanation.empty()) { + explanation = absl::StrCat(", ", explanation, ","); + } + *listener << "ctrl_deps" << explanation << " does not match expected: "; + control_dep_set->DescribeTo(listener->stream()); + } + return false; + } + return true; + } + + void DescribeTo(::std::ostream* os) const override { + std::vector predicates; + + if (name) { + predicates.push_back(absl::StrCat("name: ", *name)); + } + + if (op) { + predicates.push_back(absl::StrCat("op: ", *op)); + } + + if (assigned_device) { + predicates.push_back(absl::StrCat("assigned device: ", *assigned_device)); + } + + bool printed_something = !predicates.empty(); + + *os << absl::StrJoin(predicates, ", "); + + if (constant_value) { + printed_something = true; + *os << "constant value: " << constant_value->DebugString(); + } + + if (input_matchers) { + if (!input_matchers->empty()) { + printed_something = true; + *os << " with " << (input_matchers->size() == 1 ? "only " : "") + << "input" << (input_matchers->size() == 1 ? "" : "s") << " "; + } + + if (input_matchers->size() == 1) { + ::std::stringstream ss; + input_matchers->front().DescribeTo(&ss); + printed_something = true; + *os << "matching " << ss.str(); + } else { + int edge_idx = 0; + for (const ::testing::Matcher& matcher : (*input_matchers)) { + *os << "\n [" << edge_idx << "] matching ("; + ::std::stringstream ss; + matcher.DescribeTo(&ss); + printed_something = true; + *os << IndentAllButFirstLine(ss.str()); + *os << ")"; + edge_idx++; + } + } + } + + if (control_dep_set) { + printed_something = true; + *os << " and control deps "; + control_dep_set->DescribeTo(os); + } + + if (!printed_something) { + *os << "is any node"; + } + } + + bool MatchAndExplainInput(const Node* node, int input_idx, + ::testing::MatchResultListener* listener) const { + const Edge* edge; + if (!node->input_edge(input_idx, &edge).ok()) { + if (listener->IsInterested()) { + *listener << "\ncould not find incoming edge for input " << input_idx; + } + return false; + } + + ::testing::StringMatchResultListener inner_listener; + Input input = {edge->src(), edge->src_output()}; + if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) { + return true; + } + + if (listener->IsInterested()) { + *listener << "\ninput " << input_idx << " does not match expected:\n"; + (*input_matchers)[input_idx].DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << ", " << explanation; + } + } + return false; + } + + absl::optional op; + absl::optional name; + absl::optional assigned_device; + absl::optional constant_value; + absl::optional>> input_matchers; + absl::optional<::testing::Matcher>> + control_dep_set; +}; + +// Matches a dst and dst_output on an input edge. Today we only use this with +// dst_output=0 but we will eventually need to support multi-output operations. +class InputMatcher : public ::testing::MatcherInterface { + public: + InputMatcher(::testing::Matcher src_matcher, int src_output) + : src_matcher_(std::move(src_matcher)), src_output_(src_output) {} + + bool MatchAndExplain( + Input input, ::testing::MatchResultListener* listener) const override { + ::testing::StringMatchResultListener inner_listener; + if (!src_matcher_.MatchAndExplain(input.first, &inner_listener)) { + if (listener->IsInterested()) { + *listener << "\nsource does not match expected "; + src_matcher_.DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << "\n\t" << explanation; + } + } + return false; + } + if (input.second != src_output_) { + if (listener->IsInterested()) { + *listener << "\nexpected output slot to be " << src_output_ + << " but found " << input.second; + } + return false; + } + + return true; + } + + void DescribeTo(::std::ostream* os) const override { + if (src_output_) { + *os << "output slot: " << src_output_ << ", source: ("; + } + + src_matcher_.DescribeTo(os); + + if (src_output_) { + *os << ")"; + } + } + + private: + ::testing::Matcher src_matcher_; + int src_output_; +}; + +std::vector<::testing::Matcher> NodeMatchersToInputMatchers( + absl::Span> node_matchers) { + std::vector<::testing::Matcher> result; + absl::c_transform(node_matchers, std::back_inserter(result), + [](::testing::Matcher n) { + return ::testing::MakeMatcher(new InputMatcher(n, 0)); + }); + return result; +} +} // namespace + +::testing::Matcher impl::NodeWith( + absl::Span props) { + NodeMatcher* matcher = new NodeMatcher(); + for (const NodeMatcherProperties& prop : props) { + if (prop.name()) { + DCHECK(!matcher->name); + matcher->name = prop.name(); + } + + if (prop.op()) { + DCHECK(!matcher->op); + matcher->op = prop.op(); + } + + if (prop.constant_value()) { + DCHECK(!matcher->constant_value); + matcher->constant_value = prop.constant_value(); + } + + if (prop.assigned_device()) { + DCHECK(!matcher->assigned_device); + matcher->assigned_device = prop.assigned_device(); + } + + if (prop.input_nodes()) { + DCHECK(!matcher->input_matchers); + matcher->input_matchers = + NodeMatchersToInputMatchers(*prop.input_nodes()); + } + + if (prop.control_deps()) { + DCHECK(!matcher->control_dep_set); + matcher->control_dep_set = + ::testing::UnorderedElementsAreArray(*prop.control_deps()); + } + } + + return ::testing::MakeMatcher(matcher); +} + +impl::NodeMatcherProperties Name(string name) { + impl::NodeMatcherProperties props; + props.set_name(std::move(name)); + return props; +} + +// Matches a node with op `op`. +impl::NodeMatcherProperties Op(string op) { + impl::NodeMatcherProperties props; + props.set_op(std::move(op)); + return props; +} + +// Matches a node with assigned device `assigned_device`. +impl::NodeMatcherProperties AssignedDevice(string assigned_device) { + impl::NodeMatcherProperties props; + props.set_assigned_device(std::move(assigned_device)); + return props; +} + +impl::NodeMatcherProperties impl::Inputs( + absl::Span> inputs) { + std::vector<::testing::Matcher> inputs_vector; + absl::c_copy(inputs, std::back_inserter(inputs_vector)); + + impl::NodeMatcherProperties props; + props.set_input_nodes(std::move(inputs_vector)); + return props; +} + +impl::NodeMatcherProperties impl::CtrlDeps( + absl::Span> control_deps) { + std::vector<::testing::Matcher> control_deps_vector; + absl::c_copy(control_deps, std::back_inserter(control_deps_vector)); + + impl::NodeMatcherProperties props; + props.set_control_deps(std::move(control_deps_vector)); + return props; +} + +NodeMatcherProperties ConstantValue( + const ::tensorflow::Input::Initializer& val) { + TF_CHECK_OK(val.status); + NodeMatcherProperties props; + props.set_constant_value(val.tensor); + return props; +} + +::testing::Matcher Const( + const ::tensorflow::Input::Initializer& val) { + return NodeWith(ConstantValue(val)); +} +} // namespace matchers + +Node* FindNodeByName(Graph* g, absl::string_view name) { + for (Node* n : g->nodes()) { + if (n->name() == name) { + return n; + } + } + + return nullptr; +} +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h new file mode 100644 index 0000000000000000000000000000000000000000..0437a7e95c1eb3bdcdbe24a440dd90a5943c0894 --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers.h @@ -0,0 +1,197 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides a set of matchers for tensorflow nodes. +// +// Example usage: +// +// tensorflow::Node* node = ...; +// EXPECT_THAT(node, NodeWith(Name("name"), Op("op"), +// Inputs(NodeWith(Name("input"))))) +// +// Matchable node properties (the expressions that go inside NodeWith(...)) +// are: +// +// - Name(string): matches the node name exactly. We will probably need to +// have this take a string matcher soon in the future. +// +// - Op(string): matches the op exactly. +// +// - AssignedDevice(string): matches the assigned device exactly. +// +// - Inputs(): matches the list of non-control inputs to the node +// exactly (i.e. does not match a suffix or a prefix). +// +// - CtrlDeps(): matches the list of control dependences on the +// node exactly but in any order. +// +// - ConstantValue(tensorflow::Input::Initializer init): matches a Const node +// with the constant value `init`. Implies Op("Const"). +// +// Node properties may not be repeated in a single NodeWith(...) matcher. +// E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since ConstantValue +// implies Op("Const"), a single NodeWith matcher can't have both +// ConstantValue(...) and Op(...). + +#ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ +#define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace testing { +namespace matchers { + +namespace impl { + +// ----------------------------------------------------------------------------- +// Implementation details. + +// Properties that we match on for a particular Node. If a particular property +// is nullopt then any value for it is allowed. +class NodeMatcherProperties { + public: + using NodeSeqMatcher = std::vector<::testing::Matcher>; + + const absl::optional& name() const { return name_; } + const absl::optional& op() const { return op_; } + const absl::optional& assigned_device() const { + return assigned_device_; + } + const absl::optional& constant_value() const { + return constant_value_; + } + const absl::optional& input_nodes() const { + return input_nodes_; + } + const absl::optional& control_deps() const { + return control_deps_; + } + + void set_name(string name) { + DCHECK(IsEmpty()); + name_ = std::move(name); + } + + void set_op(string op) { + DCHECK(IsEmpty()); + op_ = std::move(op); + } + + void set_assigned_device(string assigned_device) { + DCHECK(IsEmpty()); + assigned_device_ = std::move(assigned_device); + } + + void set_constant_value(Tensor constant_value) { + DCHECK(IsEmpty()); + constant_value_ = std::move(constant_value); + op_ = "Const"; + } + + void set_input_nodes(NodeSeqMatcher input_nodes) { + DCHECK(IsEmpty()); + input_nodes_ = std::move(input_nodes); + } + + void set_control_deps(NodeSeqMatcher control_deps) { + DCHECK(IsEmpty()); + control_deps_ = std::move(control_deps); + } + + bool IsEmpty() const { + return !name().has_value() && !op().has_value() && + !input_nodes().has_value() && !control_deps().has_value(); + } + + private: + absl::optional name_; + absl::optional op_; + absl::optional assigned_device_; + absl::optional constant_value_; + absl::optional input_nodes_; + absl::optional control_deps_; +}; + +::testing::Matcher NodeWith( + absl::Span props); + +impl::NodeMatcherProperties Inputs( + absl::Span> inputs); + +impl::NodeMatcherProperties CtrlDeps( + absl::Span> control_deps); +} // namespace impl + +// ----------------------------------------------------------------------------- +// Public interface. + +// Matches a node with name `name`. +impl::NodeMatcherProperties Name(string name); + +// Matches a node with op `op`. +impl::NodeMatcherProperties Op(string op); + +// Matches a node with assigned device `assigned_device`. +impl::NodeMatcherProperties AssignedDevice(string assigned_device); + +// Matches a node with inputs `inputs`. +// +// `inputs` are ordered; `inputs`[i] must match input i. +template +impl::NodeMatcherProperties Inputs(Ts... inputs) { + return impl::Inputs({inputs...}); +} + +// Matches a node with control dependences `control_deps`. +// +// `control_deps` are unordered and will match the control deps of a node in any +// order. +template +impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) { + return impl::CtrlDeps({control_deps...}); +} + +// Matches a constant node with value `val`. +impl::NodeMatcherProperties ConstantValue( + const ::tensorflow::Input::Initializer& val); + +// The main gmock matcher. See file comment for example usage. +template +::testing::Matcher NodeWith(Ts... args) { + std::array array = {args...}; + return impl::NodeWith(array); +} + +::testing::Matcher Const( + const ::tensorflow::Input::Initializer& val); +} // namespace matchers + +// If `g` has a node named `name` returns it, otherwise returns null. +Node* FindNodeByName(Graph* g, absl::string_view name); +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..93a8994307b38ac240c22d0a18268638ac7620ae --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/node_matchers.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" + +namespace tensorflow { +namespace testing { +namespace { + +using ::testing::_; + +using testing::matchers::AssignedDevice; +using testing::matchers::ConstantValue; +using testing::matchers::CtrlDeps; +using testing::matchers::Inputs; +using testing::matchers::Name; +using testing::matchers::NodeWith; +using testing::matchers::Op; + +template +string Explain(const T& t, const M& m) { + ::testing::StringMatchResultListener listener; + EXPECT_THAT(t, ::testing::Not(m)); // For the error message. + EXPECT_FALSE(m.MatchAndExplain(t, &listener)); + return listener.str(); +} + +TEST(NodeMatchers, CheckAgainstConstant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output placeholder = + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + EXPECT_THAT(placeholder.node(), NodeWith(Op("Placeholder"))); + EXPECT_THAT(placeholder.node(), NodeWith(Name("placeholder"))); + EXPECT_THAT(placeholder.node(), + NodeWith(Op("Placeholder"), Name("placeholder"))); + EXPECT_THAT(placeholder.node(), + NodeWith(Name("placeholder"), Op("Placeholder"))); + EXPECT_THAT(placeholder.node(), NodeWith(Inputs())); + EXPECT_THAT(placeholder.node(), + NodeWith(Op("Placeholder"), Name("placeholder"), Inputs())); + + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Op("Add"))), + "\nexpected op Add but found Placeholder"); + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Name("add"))), + "\nexpected name add but found placeholder"); + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(NodeWith()))), + "\nexpected 1 inputs but node has 0"); +} + +TEST(NodeMatchers, CheckAgainstBinary) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), placeholder_a, placeholder_b); + + EXPECT_THAT(add.node(), NodeWith(Op("Add"), Name("add"), + Inputs(NodeWith(Name("placeholder_a")), + NodeWith(Name("placeholder_b"))))); + + EXPECT_EQ(Explain(add.node(), NodeWith(Inputs())), + "\nexpected 0 inputs but node has 2"); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(NodeWith(Name("blah")), _))), + "\ninput 0 does not match expected:\nname: blah, \nsource does not match " + "expected name: blah\n\t\nexpected name blah but found placeholder_a"); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(_, NodeWith(Name("blah"))))), + "\ninput 1 does not match expected:\nname: blah, \nsource does not match " + "expected name: blah\n\t\nexpected name blah but found placeholder_b"); +} + +TEST(NodeMatchers, CheckControlDependence) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + Output placeholder_c = + ops::Placeholder(root.WithOpName("placeholder_c"), DT_FLOAT); + Output placeholder_d = + ops::Placeholder(root.WithOpName("placeholder_d"), DT_FLOAT); + + root.graph()->AddControlEdge(placeholder_a.node(), placeholder_c.node()); + root.graph()->AddControlEdge(placeholder_b.node(), placeholder_c.node()); + + EXPECT_THAT(placeholder_c.node(), + NodeWith(Name("placeholder_c"), + CtrlDeps(NodeWith(Name("placeholder_a")), + NodeWith(Name("placeholder_b"))))); + EXPECT_THAT(placeholder_d.node(), + NodeWith(Name("placeholder_d"), CtrlDeps())); + + EXPECT_EQ( + Explain(placeholder_c.node(), NodeWith(CtrlDeps())), + "ctrl_deps, which has 2 elements, does not match expected: is empty"); + EXPECT_EQ(Explain(placeholder_d.node(), NodeWith(CtrlDeps(NodeWith()))), + "ctrl_deps does not match expected: has 1 element and that element " + "is any node"); +} + +TEST(NodeMatchers, ConstVaulue) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output placeholder = + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + Output const_0d = ops::Const(root.WithOpName("const_0d"), 42); + + Output const_2d = ops::Const(root.WithOpName("const_2d"), {{1, 2}, {4, 3}}); + + EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42))); + EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42), Name("const_0d"))); + + EXPECT_THAT(const_2d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))); + + EXPECT_EQ(Explain(placeholder.node(), NodeWith(ConstantValue(42))), + "\nexpected op Const but found Placeholder"); + EXPECT_EQ( + Explain(const_0d.node(), NodeWith(ConstantValue(43))), + "\nmismatch in constant tensor at index 0 expected = 43 actual = 42"); + EXPECT_EQ( + Explain(const_0d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))), + "\nwas looking for tensor with 4 elements, found tensor with 1 elements"); + EXPECT_EQ( + Explain(const_2d.node(), NodeWith(ConstantValue(42))), + "\nwas looking for tensor with 1 elements, found tensor with 4 elements"); +} + +TEST(NodeMatchers, AssignedDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + + Output assigned_add = + ops::Add(root.WithOpName("assigned_add"), placeholder_a, placeholder_b); + assigned_add.node()->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:CPU:0"); + + Output unassigned_add = + ops::Add(root.WithOpName("unassigned_add"), placeholder_a, placeholder_b); + + EXPECT_THAT( + assigned_add.node(), + NodeWith(AssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0"))); + EXPECT_THAT(unassigned_add.node(), NodeWith(AssignedDevice(""))); + + EXPECT_EQ(Explain(unassigned_add.node(), + NodeWith(AssignedDevice( + "/job:localhost/replica:0/task:0/device:CPU:0"))), + "\nexpected assigned_device " + "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\""); +} + +} // namespace +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index c9e46bc1475aed0e35a48765ad70eef4362e8281..f72224545b25bc7100e0b6788e6fbf0a7ca63dad 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -4,6 +4,8 @@ package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + cc_library( name = "xla_ops", srcs = ["xla_ops.cc"], @@ -11,9 +13,8 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "parallel_check_op", - srcs = ["parallel_check_op.cc"], - deps = ["//tensorflow/core:framework"], - alwayslink = 1, +tf_gen_op_wrapper_py( + name = "xla_ops_wrapper_py", + out = "xla_ops.py", + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index f2473d98ffd5dae55983e601b8d2d65af6a6d54c..bcd1a29b1ff789b5674a21ff66cc6d23a809afc5 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +using shape_inference::InferenceContext; + REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") @@ -32,4 +36,58 @@ REGISTER_OP("XlaLaunch") .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); +REGISTER_OP("XlaClusterOutput") + .Input("input: T") + // Note: when replication is supported, this op will have N outputs. + .Output("outputs: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(0)); + } + return Status::OK(); + }) + .Doc( + "Operator that connects the output of an XLA computation to other " + "consumer graph nodes."); + +REGISTER_OP("_XlaCompile") + .Input("constants: Tconstants") + .Attr("Tconstants: list(type) >= 0") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Input("resources: Nresources * resource") + .Attr("Nresources: int >= 0") + .Output("key: string") + .Output("compilation_successful: bool") + .Attr("function: func") + // The compilation cache is stateful. + .SetIsStateful() + .Doc(R"(XLA Compile Op. For use by the XLA JIT only. + +Compiles a TensorFlow function into an XLA LocalExecutable and returns a key +that _XlaRun can use to look up the LocalExecutable and execute it. + +key: A key that can be used to look up the local executable compiled by the + node and associated metadata. + +compilation_successful: True iff the compilation was successful. Always true +for now. +)"); + +REGISTER_OP("_XlaRun") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Output("results: Tresults") + .Attr("Tresults: list(type) >= 0") + .Input("key: string") + // XLA random-number generation ops are stateful. + // TODO(phawkins): create stateful and non-stateful variants of _XlaRun. + .SetIsStateful() + .Doc(R"(XLA Run Op. For use by the XLA JIT only. + +Executes a TensorFlow function previously compiled into a LocalExecutable by an +_XlaCompile op. +)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..10fc9e85d927ffe2416d6d9e6dfd24b286fbf1a0 --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -0,0 +1,322 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace tensorflow { +namespace { +Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, + absl::Span post_order) { + // Find nodes that have at least one user outside their cluster that expects + // hostmem output. These nodes should be cloned to outside the cluster to + // avoid the device-host copy we'd otherwise need. + + MemoryTypeVector input_mtypes, output_mtypes; + + for (Node* n : post_order) { + absl::optional from_cluster = GetXlaClusterForNode(*n); + if (!from_cluster) { + continue; + } + + // We assume the only XLA-auto-clusterable operations with side effects are + // resource variable updates. We can't execute these twice. + if (HasResourceInputOrOutput(*n)) { + continue; + } + + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + n->def(), &input_mtypes, + &output_mtypes)); + for (const Edge* e : n->out_edges()) { + Node* dst = e->dst(); + + if (e->IsControlEdge()) { + continue; + } + + bool edge_incurs_extra_device_to_host_copy; + if (output_mtypes[e->src_output()] == DEVICE_MEMORY) { + // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then + // keep the node clustered -- XLA will also produce the output in device + // memory and we will get some benefit from clustering. + edge_incurs_extra_device_to_host_copy = false; + } else { + MemoryTypeVector dst_input_mtypes, dst_output_mtypes; + DeviceType dst_device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + dst->def(), &dst_input_mtypes, + &dst_output_mtypes)); + edge_incurs_extra_device_to_host_copy = + dst_input_mtypes[e->dst_input()] == HOST_MEMORY; + } + + if (!edge_incurs_extra_device_to_host_copy) { + continue; + } + + // Check if `dst` is in a different cluster, unclustered, or about to be + // partially declustered (here we rely on the post-order traversal order). + // If yes, decluster `n` to avoid the device-to-host memcpy. + absl::optional dst_cluster = + result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst); + if (from_cluster != dst_cluster) { + CHECK(result->insert(n).second); + break; + } + } + } + return Status::OK(); +} + +Status PartiallyDeclusterNode(Graph* graph, Node* n) { + absl::string_view cluster_name = *GetXlaClusterForNode(*n); + absl::InlinedVector out_edges_to_clone; + for (const Edge* out_edge : n->out_edges()) { + if (out_edge->IsControlEdge()) { + continue; + } + + Node* dst = out_edge->dst(); + absl::optional dst_cluster_name = + GetXlaClusterForNode(*dst); + if (dst_cluster_name != cluster_name) { + out_edges_to_clone.push_back(out_edge); + } + } + + CHECK(!out_edges_to_clone.empty()) << n->DebugString(); + + NodeDef ndef = n->def(); + ndef.set_name(absl::StrCat(n->name(), "/declustered")); + RemoveFromXlaCluster(&ndef); + Status s; + Node* cloned_node = graph->AddNode(ndef, &s); + cloned_node->set_assigned_device_name(n->assigned_device_name()); + TF_RETURN_IF_ERROR(s); + + for (const Edge* in_edge : n->in_edges()) { + graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node, + in_edge->dst_input()); + } + + for (const Edge* out_edge_to_clone : out_edges_to_clone) { + graph->AddEdge(cloned_node, out_edge_to_clone->src_output(), + out_edge_to_clone->dst(), out_edge_to_clone->dst_input()); + graph->RemoveEdge(out_edge_to_clone); + } + + return Status::OK(); +} + +bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } + +// Clones nodes to outside their cluster to avoid device-to-host copies. For +// instance, converts this: +// +// ..... +// | +// v +// A_Clustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// to: +// +// ..... +// | | +// | +-------------+ +// | | +// v v +// A_Clustered A_Unclustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// where the ===> arrow has a hostmem source and destination and would entail a +// device to host copy if the source and destination were not in the same XLA +// cluster. +Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { + // When deciding whether to decluster a particular node, we base our decision + // on if we've decided that some of its consumers have to be declustered too. + // Iterating the graph in post-order guarantees that consumers have been + // visited before producers. + std::vector post_order; + GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); + + gtl::FlatSet nodes_to_partially_decluster; + TF_RETURN_IF_ERROR( + FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); + + if (VLOG_IS_ON(3)) { + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + VLOG(3) << n->DebugString(); + } + } + } + + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n)); + } + } + + nodes_to_partially_decluster.clear(); + TF_RETURN_IF_ERROR( + FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); + CHECK(nodes_to_partially_decluster.empty()); + + return Status::OK(); +} + +bool IsIntraClusterEdge(const Edge& edge) { + absl::optional src_cluster_name = + GetXlaClusterForNode(*edge.src()); + absl::optional dst_cluster_name = + GetXlaClusterForNode(*edge.dst()); + return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name; +} + +Status MustCompileNode(const Node* n, bool* result) { + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + *result = false; + } else { + *result = registration->requires_compilation; + } + + return Status::OK(); +} + +// Declusters nodes to reduce the number of times we think we need to recompile +// a TensorFlow graph. +// +// Abstractly, if we have a cluster of this form: +// +// x0 = arg0 +// x1 = arg1 +// ... +// shape = f(x0, x1, ...) +// result = Reshape(input=, new_shape=shape) +// +// then pulling `f` out of the cluster may reduce the number of compilations and +// will never increase the number of compilations. +// +// We may reduce the number of compilations if f is many to one. For instance +// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different +// compilations if f is in the cluster but only one compilation if f is outside +// the cluster. +// +// Declustering f will increase the number of compilations only if f is a +// one-to-many "function" i.e. isn't a function at all. RNG is one possible +// example, depending on how we look at it. But we never create clusters where +// such f's would be marked as must-be-constant. +// +// We assume here that the extra repeated (repeated compared to a clustered f +// where it will always be constant folded) host-side computation of f does not +// regress performance in any significant manner. We will have to revisit this +// algorith with a more complex cost model if this assumption turns out to be +// incorrect. +Status DeclusterNodesToReduceRecompilations(Graph* graph) { + std::vector compile_time_const_nodes(graph->num_node_ids()); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); + + std::vector rpo; + GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); + for (Node* n : rpo) { + if (!compile_time_const_nodes[n->id()]) { + continue; + } + + absl::string_view cluster_name = *GetXlaClusterForNode(*n); + bool node_on_cluster_edge = + absl::c_all_of(n->in_edges(), [&](const Edge* e) { + absl::optional incoming_cluster = + GetXlaClusterForNode(*e->src()); + return !incoming_cluster || *incoming_cluster != cluster_name; + }); + + // We don't want to decluster F in a graph like + // + // Input -> OP -> Shape -> F -> Reshape + // + // Doing so will break up the cluster. Even if we were okay with breaking + // up the cluster we will at least have to relabel the two clusters to have + // different cluster names. + // + // We may want to revisit this in the future: we may have cases where OP is + // a small computation that does not benefit from XLA while XLA can optimize + // everything that follows the Reshape. In these cases it may be wise to + // remove Input, OP, Shape and F from the cluster, if F is a many-to-one + // function. + // + // Note that we do do the right thing for graphs like: + // + // Input -> F0 -> F1 -> Reshape + // + // Since we iterate in RPO, we'll first encounter F0, decluster it, then + // encounter F1, decluster it and so on. + if (node_on_cluster_edge) { + bool must_compile_node; + TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node)); + if (!must_compile_node) { + VLOG(3) << "Declustering must-be-constant node " << n->name(); + RemoveFromXlaCluster(n); + } + } + } + + return Status::OK(); +} + +} // namespace + +Status PartiallyDeclusterPass::Run( + const GraphOptimizationPassOptions& options) { + // NB! In this pass we assume the only XLA-auto-clusterable operations that + // may have side effects are resource variable operations so we don't cluster + // those. The pass will have to be updated if this assumption becomes + // invalid. + + Graph* graph = options.graph->get(); + + TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph)); + TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph)); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/core/util/status_util.h b/tensorflow/compiler/jit/partially_decluster_pass.h similarity index 51% rename from tensorflow/core/util/status_util.h rename to tensorflow/compiler/jit/partially_decluster_pass.h index ea92f61dce0b4e3a9470e25d96dbb599954ea46f..cfc4ddb5630bec91d6942c983ce1efae3a735c43 100644 --- a/tensorflow/core/util/status_util.h +++ b/tensorflow/compiler/jit/partially_decluster_pass.h @@ -13,24 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_ -#define TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { -// Creates a tag to be used in an exception error message. This can be parsed by -// the Python layer and replaced with information about the node. +// Clones or moves nodes from within a cluster to outside the cluster if +// profitable. There are two reasons why we do this: // -// For example, error_format_tag(node, "${file}") returns -// "^^node:NODE_NAME:${line}^^" which would be rewritten by the Python layer as -// e.g. "file/where/node/was/created.py". -inline string error_format_tag(const Node& node, const string& format) { - return strings::StrCat("^^node:", node.name(), ":", format, "^^"); -} +// - Reducing device-to-host copies. +// - Reducing the number of XLA recompilations. +class PartiallyDeclusterPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; } // namespace tensorflow -#endif // TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_ +#endif // TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0feb73a89e7050e8c413e5a733da1d87775b0ba3 --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -0,0 +1,409 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" + +#include "absl/memory/memory.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +REGISTER_OP("FakeNullary").Output("out: float"); + +REGISTER_OP("FakeBinary") + .Input("host_in: float") + .Input("device_in: float") + .Output("host_out: float") + .Output("device_out: float"); + +REGISTER_OP("FakeResourceVar").Output("out: resource"); + +REGISTER_OP("FakeResourceUpdate") + .Input("in: resource") + .Output("out: resource") + .Output("something_else: float"); + +class FakeBinaryOp : public OpKernel { + public: + explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +class FakeResourceUpdateOp : public OpKernel { + public: + explicit FakeResourceUpdateOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +REGISTER_KERNEL_BUILDER(Name("FakeBinary") + .Device(DEVICE_CPU) + .HostMemory("host_in") + .HostMemory("host_out"), + FakeBinaryOp); + +REGISTER_KERNEL_BUILDER( + Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"), + FakeResourceUpdateOp); + +Status PartiallyDecluster(std::unique_ptr* graph) { + FixupSourceAndSinkEdges(graph->get()); + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + PartiallyDeclusterPass pass; + return pass.Run(opt_options); +} + +Node* FindNodeByName(const Graph& graph, const string& name) { + for (Node* node : graph.nodes()) { + if (node->name() == name) { + return node; + } + } + return nullptr; +} + +bool GetInputsForNode(const Graph& graph, const string& node_name, + std::vector* inputs) { + const Node* node = FindNodeByName(graph, node_name); + if (node == nullptr) { + return false; + } + for (const Edge* e : node->in_edges()) { + inputs->push_back(e->src()); + } + std::sort(inputs->begin(), inputs->end(), NodeComparatorName()); + return true; +} + +TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector unclustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + std::vector clustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer", + &clustered_consumer_inputs)); + ASSERT_EQ(clustered_consumer_inputs.size(), 2); + EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DifferentClusters) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer/declustered"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + // The first input is hostmem and the second input is devicemem. + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", input, clustered_producer, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* resource_var = ops::SourceOp("FakeResourceVar", + builder.opts().WithName("ResourceVar")); + Node* clustered_producer = + ops::UnaryOp("FakeResourceUpdate", resource_var, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer_0 = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer0")); + Node* clustered_producer_1 = + ops::BinaryOp("FakeBinary", clustered_producer_0, input, + builder.opts().WithName("ClusteredProducer1")); + ops::BinaryOp("FakeBinary", clustered_producer_1, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector unclustered_consumer_inputs, declustered_producer_1_inputs; + + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer1/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered", + &declustered_producer_1_inputs)); + ASSERT_EQ(declustered_producer_1_inputs.size(), 2); + EXPECT_EQ(declustered_producer_1_inputs[0]->name(), + "ClusteredProducer0/declustered"); + EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input"); +} + +void AddToCluster(absl::Span nodes, + absl::string_view cluster_name) { + for (Node* n : nodes) { + n->AddAttr(kXlaClusterAttr, string(cluster_name)); + } +} + +TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({shape.node(), reshape.node()}, "cluster_0"); + + auto graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt); +} + +TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT, + ops::Placeholder::Attrs{}); + Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b); + Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul); + + Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()}, + "cluster_0"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); +} + +TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({reshape.node()}, "cluster_0"); + AddToCluster({shape.node()}, "cluster_1"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({shape.node(), reshape.node()}, "cluster_0"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + + // This is needed to register the XLA_GPU device. + std::vector devices; + TF_ASSERT_OK(DeviceFactory::AddDevices( + SessionOptions(), "/job:localhost/replica:0/task:0", &devices)); + + // Scope::ToGraph loses the assigned device name since it goes through + // GraphDef/NodeDef which does not have a field for the assigned device name. + Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + n->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:XLA_GPU:0"); + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); + + for (Device* d : devices) { + delete d; + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..56e35c0059124015266ffabdf583c8724c8e0908 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -0,0 +1,336 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// ALGORITHM OVERVIEW +// ================== +// +// An XLA cluster hoists all resource reads to be beginning of the cluster +// execution and all the resource writes to the end. This means it cannot +// enforce arbitrary ordering dependencies (via control or data edges) between +// resource operations. Since all resource reads happen before all resource +// writes, edges constraining resource reads to happen before resource writes +// are fine, but all other kinds of edges are problematic. This analysis +// computes the set of pairs of resource operations that cannot be put in the +// same cluster because XLA cannot respect the dependencies between them in the +// TensorFlow program. +// +// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write +// dependencies. +// +// Specifically the result computed by this analysis contains the edge {W, R} +// iff all of these hold true: +// +// - In the graph (g - {edges from NextIteration to Merge}) there is a path +// from W to R. +// - IsEdgeSafe(W, R) == False [defined below] +// - W != R (note: some resource operations both read from and write to +// resource variables). +// +// The result is incorrect around loops because we ignore edges from +// NextIteration to Merge, but that should be fine because we don't cluster +// these edges. For instance, in: +// +// Init -----> Merge <-------+ +// | | +// v | +// Read | +// | | +// v | +// Write | +// | | +// v | +// NextIteration --+ +// +// we won't put (Read, Write) in the returned set. This is fine if +// auto-clustering can only cluster the Read->Write edge, but it is a problem if +// it clusters the Write->NextIteration->Merge->Read edges instead. The same +// problem is present for the functional version of the loop above. We rely on +// auto-clustering to not cluster control flow edges like NextIteration->Merge. +// This is enough to avoid the explicit-control-flow problem shown above. One +// way to think about this is that we only care about cases where two nodes, A +// and B, would normally have been put in the same cluster but cannot legally be +// in the same cluster because of resourcevar-dependencies. If A and B would +// normally have been put in the same cluster then all paths between A and B +// would have to be clusterable (otherwise we'd have introduced a cycle). Ergo +// there could not have been a NextIteration->Merge edge between A and B since +// we don't cluster these edges. +// +// We also rely on auto-clustering to not cluster functional control flow nodes +// that contain resource operations. +// +// IMPLEMENTATION +// -------------- +// +// We traverse the graph minus backedges in reverse post order, mapping each +// node to the set of resource operation reaching that node. Since we visit +// producers before consumers, we can construct the set of reaching operations +// by taking the union of the operations reaching the input nodes. These +// "reaching resource operations" can then be used to create the pairs of +// incompatible nodes using `IsEdgeSafe`. + +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { +// Returns true if `n` may call a function. +Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def, + bool* out_result) { + if (flib_def->Contains(n.type_string())) { + *out_result = true; + } else { + *out_result = + std::any_of(n.def().attr().begin(), n.def().attr().end(), + [](const std::pair& name_attr_pair) { + return name_attr_pair.second.has_func(); + }); + } + + return Status::OK(); +} + +// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is +// not a resource operation recognized by XLA then sets `out_resource_op_kind` +// to nullopt. +Status XlaResourceOpKindForNode( + const Node& n, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + absl::optional* out_resource_op_kind) { + bool should_ignore = false; + if (resource_ops_to_ignore) { + TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore)); + } + if (should_ignore) { + *out_resource_op_kind = absl::nullopt; + return Status::OK(); + } + + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string()); + if (op_info) { + *out_resource_op_kind = op_info->kind(); + return Status::OK(); + } + + // We conservatively assume that functions will both read and write resource + // variables. In the future we may consider doing some form of + // inter-procedural analysis. + bool may_call_function; + TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function)); + if (may_call_function) { + *out_resource_op_kind = XlaResourceOpKind::kReadWrite; + } else { + *out_resource_op_kind = absl::nullopt; + } + + return Status::OK(); +} + +// Returns true if a control or data dependence from a TensorFlow operation of +// resource op kind `from` to a TensorFlow operation of resource op kind `to` +// can be represented by an XLA cluster and needs no special handling around +// auto-jit. +bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { + // XLA clusters forces all reads to happen before all writes, which means the + // kinds of edges it can faithfully represent are: Read->Write, Read->Modify, + // Modify->Write, Read->Read, Write->Write. + // + // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write + // dependencies. + return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite; +} + +using ResourceOp = std::pair; + +string ResourceOpToString(const ResourceOp& resource_op) { + return absl::StrCat( + resource_op.first, ": ", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); +} + +// A copy-on-write set used to store the set of ResourceOps reaching a node in a +// TensorFlow graph. +// +// TODO(sanjoy): It may be useful to pull this out into its own header at some +// point. +class ResourceOpSet { + private: + using Impl = gtl::FlatSet; + + public: + ResourceOpSet() = default; + + // Adds all ResourceOp s in `other` to this set. + void Add(const ResourceOpSet& other) { + CHECK(!frozen_); + if (other.impl_ == impl_) { + other.frozen_ = true; + return; + } + + if (!impl_) { + other.frozen_ = true; + impl_ = other.impl_; + return; + } + + for (ResourceOp resource_op : other) { + Add(resource_op); + } + } + + void Add(const ResourceOp& resource_op) { + CHECK(!frozen_); + if (!IsCopy() && Contains(resource_op)) { + // We can avoid the copy if the item we want to insert already exists. + return; + } + + EnsureIsCopied(); + impl_->insert(resource_op); + } + + Impl::const_iterator begin() const { + return impl_ ? impl_->begin() : GetEmptyImpl()->begin(); + } + + Impl::const_iterator end() const { + return impl_ ? impl_->end() : GetEmptyImpl()->end(); + } + + bool Contains(const ResourceOp& resource_op) const { + return impl_ != nullptr && impl_->count(resource_op); + } + + private: + bool IsCopy() const { return storage_ != nullptr; } + + void EnsureIsCopied() { + if (storage_ == nullptr) { + storage_ = absl::make_unique(); + for (ResourceOp op : *this) { + storage_->insert(op); + } + impl_ = storage_.get(); + } + } + + static Impl* GetEmptyImpl() { + static Impl* empty_impl = new Impl; + return empty_impl; + } + + Impl* impl_ = nullptr; + std::unique_ptr storage_; + + // frozen_ is true if there is another set pointing to this set's impl_. We + // can no longer add elements to this set in that case since the sets pointing + // to this set expect the contents of this set to be stable. + mutable bool frozen_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet); +}; + +string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { + std::vector elements_debug_string; + std::transform(resource_op_set.begin(), resource_op_set.end(), + std::back_inserter(elements_debug_string), ResourceOpToString); + return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); +} + +string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { + return absl::StrCat( + "[", n.name(), ": ", n.type_string(), "(", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); +} +} // namespace + +Status ComputeIncompatibleResourceOperationPairs( + const Graph& g, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + std::vector>* result) { + CHECK(result->empty()); + + std::vector rpo; + GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + auto resource_op_set_for_node = + absl::make_unique(g.num_node_ids()); + + const bool vlog = VLOG_IS_ON(2); + + for (Node* n : rpo) { + absl::optional op_kind; + TF_RETURN_IF_ERROR(XlaResourceOpKindForNode( + *n, flib_def, resource_ops_to_ignore, &op_kind)); + + ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()]; + + // Merge the reaching resource operations for all the incoming edges to + // create the set of all possible resource ops reaching `n`. + for (const Edge* e : n->in_edges()) { + if (n->IsMerge() && e->src()->IsNextIteration()) { + // Ignore back-edges (see file comment). + continue; + } + + const ResourceOpSet& incoming_op_set = + resource_op_set_for_node[e->src()->id()]; + resource_op_set->Add(incoming_op_set); + } + + // Add to the "incompatible resource ops" set if necessary. + if (op_kind) { + for (ResourceOp incoming_op : *resource_op_set) { + if (IsEdgeSafe(incoming_op.second, *op_kind)) { + continue; + } + + if (vlog) { + VLOG(2) << "Unsafe edge: " + << NodeToString(*g.FindNodeId(incoming_op.first), + incoming_op.second) + << " -> " << NodeToString(*n, *op_kind); + } + result->push_back({incoming_op.first, n->id()}); + } + + resource_op_set->Add({n->id(), *op_kind}); + } + + if (vlog) { + VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set); + } + } + + std::sort(result->begin(), result->end()); + CHECK(std::unique(result->begin(), result->end()) == result->end()); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..ae8cfeecad9b9cd631db3e9865bb3c3ff28a2e48 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +// An XLA cluster hoists all resource reads to be beginning of the cluster +// execution and all the resource writes to the end. This means it cannot +// enforce arbitrary ordering dependencies (via control or data edges) between +// resource operations. Since all resource reads happen before all resource +// writes, edges constraining resource reads to happen before resource writes +// are fine, but all other kinds of edges are problematic. This analysis +// returns the set of pairs of resource operations that cannot be put in the +// same cluster because XLA cannot respect the dependencies between them in the +// TensorFlow program. +// +// The restrictions are not transitive: it is fine to put A and C in the same +// cluster even if the returned set contains (A,B) and (B,C). +// +// In other words, if these pairs are seen as edges in an undirected graph of +// the nodes in `g` then auto-clustering is at least as constrained as the graph +// coloring problem on this graph. +// +// +// For instance if we auto-cluster all operations in this TensorFlow graph: +// +// ReadVariablepOp0 -> ReadVariableOp1 +// | +// v +// AssignVariableOp0 -> AssignVariableOp1 +// +// we will lose the ReadVariablepOp0 -> ReadVariableOp1 and the +// AssignVariableOp0 -> AssignVariableOp1 dependencies. I.e. it is possible for +// XlaLaunchOp to issue ReadVariableOp1 before ReadVariablepOp0 since it reads +// all the resource variables when the cluster starts executing without any +// particular ordering between them; same holds for the AssignVariableOp0 -> +// AssignVariableOp1 edge. The ReadVariableOp1 -> AssignVariableOp0 edge will +// be respected by XlaLaunchOp though because all reads happen before all +// writes. +// +// +// NB! The result computed by this analysis assumes that we don't auto-cluster +// back-edges (i.e. the edges from NextIteration to Merge). +// +// NB! The result computed by this analysis assumes that we don't auto-cluster +// functional control flow nodes containing resource operations. +// +// If `resource_ops_to_ignore` is set then nodes for which it returns true are +// ignored (we pretend these nodes are not resource operations). +Status ComputeIncompatibleResourceOperationPairs( + const Graph& g, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + std::vector>* result); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e54b547abcfea698fe79e81dce547ea7858ff829 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -0,0 +1,540 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Node* MakeRead(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output read = + ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); + return read.node(); +} + +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle, + value_to_write); + return assign_op.operation.node(); +} + +Node* MakeModify(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f); + ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id), + var_handle, value_to_write); + return assign_add_op.operation.node(); +} + +Node* MakeNeutral(const Scope& scope, const string& id) { + return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); +} + +Status ComputeIncompatiblePairs(Graph* g, + std::vector>* result) { + FixupSourceAndSinkEdges(g); + return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {}, + result); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], write_read_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 0); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) { + Scope root = Scope::NewRootScope().ExitOnError(); + + MakeRead(root, "R"); + MakeWrite(root, "W"); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 0); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + + root.graph()->AddControlEdge(read, modify); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 1); + std::pair read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + + root.graph()->AddControlEdge(modify, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair modify_read_pair = {modify->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_read_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(modify, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 1); + std::pair modify_write_pair = {modify->id(), write->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_write_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, modify); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_modify_pair = {write->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], write_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, modify); + root.graph()->AddControlEdge(modify, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 2); + std::pair modify_write_pair = {modify->id(), write->id()}; + std::pair read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); + EXPECT_EQ(incompatible_pairs[1], modify_write_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, modify); + root.graph()->AddControlEdge(modify, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 3); + + std::pair write_modify_pair = {write->id(), modify->id()}; + std::pair modify_read_pair = {modify->id(), read->id()}; + std::pair write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_read_pair); + EXPECT_EQ(incompatible_pairs[1], write_read_pair); + EXPECT_EQ(incompatible_pairs[2], write_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + root.graph()->AddControlEdge(read, modify); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 3); + + std::pair write_modify_pair = {write->id(), modify->id()}; + std::pair write_read_pair = {write->id(), read->id()}; + std::pair read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); + EXPECT_EQ(incompatible_pairs[1], write_read_pair); + EXPECT_EQ(incompatible_pairs[2], write_modify_pair); +} + +FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, + /*attr_def*/ + {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)}, + /*ret_def=*/{{"out", "out:output:0"}}); + *flib_def.add_function() = std::move(func); + return flib_def; +} + +Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name, + Status* status) { + NodeDef call_node; + call_node.set_name(node_name); + call_node.set_op(callee_name); + return graph->AddNode(call_node, status); +} + +TEST(ResourceOperationSafetyAnalysisTest, CallRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(call, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair call_read_edge = {call->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], call_read_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadCall) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(read, call); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair read_call_edge = {read->id(), call->id()}; + EXPECT_EQ(incompatible_pairs[0], read_call_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, CallWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(call, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair call_write_edge = {call->id(), write->id()}; + EXPECT_EQ(incompatible_pairs[0], call_write_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteCall) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(write, call); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_call_edge = {write->id(), call->id()}; + EXPECT_EQ(incompatible_pairs[0], write_call_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + NameAttrList fn; + fn.set_name("Const_func"); + Node* symbolic_gradient = + ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)}, + /*Tout=*/{DT_FLOAT}, fn) + .output[0] + .node(); + + root.graph()->AddControlEdge(symbolic_gradient, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair symbolic_gradient_read_edge = {symbolic_gradient->id(), + read->id()}; + EXPECT_EQ(incompatible_pairs[0], symbolic_gradient_read_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + NameAttrList fn; + fn.set_name("Const_func"); + Node* symbolic_gradient = + ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)}, + /*Tout=*/{DT_FLOAT}, fn) + .output[0] + .node(); + + root.graph()->AddControlEdge(write, symbolic_gradient); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_symbolic_gradient_edge = {write->id(), + symbolic_gradient->id()}; + EXPECT_EQ(incompatible_pairs[0], write_symbolic_gradient_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* neutral_0 = MakeNeutral(root, "N0"); + Node* read_0 = MakeRead(root, "R0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral_1 = MakeNeutral(root, "N1"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral_0); + root.graph()->AddControlEdge(neutral_0, read_0); + root.graph()->AddControlEdge(read_0, write_1); + root.graph()->AddControlEdge(write_1, neutral_1); + root.graph()->AddControlEdge(neutral_1, read_1); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 5); + std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; + std::pair write_0_write_1_pair = {write_0->id(), write_1->id()}; + std::pair read_0_read_1_pair = {read_0->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_write_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[3], read_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[4], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral = MakeNeutral(root, "N"); + Node* read_0 = MakeRead(root, "R0"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral); + root.graph()->AddControlEdge(write_1, neutral); + root.graph()->AddControlEdge(neutral, read_0); + root.graph()->AddControlEdge(neutral, read_1); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 4); + std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair write_1_read_0_pair = {write_1->id(), read_0->id()}; + std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair); + EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral = MakeNeutral(root, "N"); + Node* read_0 = MakeRead(root, "R0"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral); + root.graph()->AddControlEdge(write_1, neutral); + root.graph()->AddControlEdge(neutral, read_0); + root.graph()->AddControlEdge(neutral, read_1); + root.graph()->AddControlEdge(write_1, read_1); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 4); + std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair write_1_read_0_pair = {write_1->id(), read_0->id()}; + std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair); + EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, Loop) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT); + Output loop_cond = ops::Placeholder(root.WithOpName("init"), DT_BOOL); + Output enter_value = + ops::internal::Enter(root.WithOpName("enter"), init_value, "fr"); + ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value}); + ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName("exit"), iv.output); + Output next_iteration = + ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true); + TF_ASSERT_OK( + root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)); + + Node* write = MakeWrite(root, "W"); + Node* read = MakeRead(root, "R"); + + root.graph()->AddControlEdge(iv.output.node(), write); + root.graph()->AddControlEdge(write, read); + root.graph()->AddControlEdge(read, next_iteration.node()); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + + std::pair write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], write_read_pair); +} + +bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { + return arg_def.type() == DT_RESOURCE; +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index a5628b12a27c9ed052e22c784517a07f2c1c059a..f85121ca27ad3da918315f93b28e9000dfd65e67 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -51,8 +53,8 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, }; string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); + absl::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); path.resize(path_size); for (int32 node_id : path) { string ascii_art; @@ -63,7 +65,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, } else { ascii_art = "+-- "; } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + absl::StrAppend(&description, ascii_art, node_name(node_id), "\n"); } return description; } @@ -185,4 +187,51 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } +absl::optional GetXlaClusterForNode(const Node& node) { + const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); + if (attr_value == nullptr) { + return absl::nullopt; + } + Status s = AttrValueHasType(*attr_value, "string"); + if (!s.ok()) { + return absl::nullopt; + } + return attr_value->s(); +} + +bool HasResourceInputOrOutput(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end() || + std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); +} + +void RemoveFromXlaCluster(NodeDef* node_def) { + node_def->mutable_attr()->erase(kXlaClusterAttr); +} + +void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); } + +Status AdjustCycleDetectionGraphForResourceOps( + const Graph* graph, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + GraphCycles* cycles) { + std::vector> unsafe_deps; + TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs( + *graph, flib_def, resource_ops_to_ignore, &unsafe_deps)); + + // An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are + // operations that interact with resource variables, must not be put in the + // same cluster. We enforce this constraint by creating a phantom node, X, + // and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P + // and Q together since that would create a cycle with X. + + for (std::pair unsafe_dep : unsafe_deps) { + int phantom_node_id = cycles->NewNode(); + CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id)); + CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second)); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index bcce082aaf6044ff0654efa4d78c0f493a350d00..ba218f3315d2607c47342fdade0403678faa2362 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" @@ -44,6 +45,26 @@ bool HasForwardedRefInput(const Node& node); // the enclosing graph. Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); +// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, +// otherwise returns nullopt. +absl::optional GetXlaClusterForNode(const Node& node); + +// Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(NodeDef* node_def); + +// Removes `node` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(Node* node); + +// Returns true if `node` has a DT_RESOURCE typed input or output. +bool HasResourceInputOrOutput(const Node& node); + +// Adds edges to `cycles` to prevent clustering resource operations that cannot +// be legally clustered. +Status AdjustCycleDetectionGraphForResourceOps( + const Graph* graph, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + GraphCycles* cycles); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 2cb351e1ecdb4523a8652886af156540e4736b18..65bbf3efe85ba30f44531ff6d54b041786dca0a5 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 7140d47a9421ec73d0144e855b490f89569e6ae9..3aa9e9c7ed2dd3b7480f40e868c6b07192b68294 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -67,12 +67,12 @@ string XlaCompilationCache::DebugString() { string XlaCompilationCache::SignatureDebugString(const Signature& sig) { string result = sig.name; for (const auto& a : sig.arg_types) { - strings::StrAppend(&result, ",", DataTypeString(a.first), - a.second.DebugString()); + absl::StrAppend(&result, ",", DataTypeString(a.first), + a.second.DebugString()); } for (const auto& v : sig.arg_values) { - strings::StrAppend(&result, "; ", v.DebugString()); + absl::StrAppend(&result, "; ", v.DebugString()); } return result; } @@ -230,7 +230,7 @@ Status XlaCompilationCache::Compile( const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + const XlaCompiler::CompileOptions& compile_options) { return CompileImpl(options, function, constant_args, variable_args, ctx, compilation_result, executable, compile_options, false); } @@ -241,7 +241,7 @@ Status XlaCompilationCache::CompileSingleOp( const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + const XlaCompiler::CompileOptions& compile_options) { const NodeDef& def = ctx->op_kernel().def(); NameAttrList name; name.set_name(def.op()); @@ -256,10 +256,10 @@ Status XlaCompilationCache::CompileImpl( const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options, + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op) { CHECK_NE(executable, nullptr); - VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); + VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { VLOG(2) << "num_inputs=" << ctx->num_inputs() @@ -310,7 +310,7 @@ Status XlaCompilationCache::CompileImpl( // cache eviction. mutex_lock entry_lock(entry->mu); if (!entry->compiled) { - VLOG(1) << "Compilation cache miss for signature: " + VLOG(2) << "Compilation cache miss for signature: " << SignatureDebugString(signature); tensorflow::Env* env = tensorflow::Env::Default(); const uint64 compile_start_us = env->NowMicros(); @@ -324,13 +324,12 @@ Status XlaCompilationCache::CompileImpl( entry->compiled = true; if (compile_single_op) { - entry->compilation_status = compiler.CompileSingleOp( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - signature.name, ctx, args, &entry->compilation_result); + entry->compilation_status = + compiler.CompileSingleOp(compile_options, signature.name, ctx, args, + &entry->compilation_result); } else { entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + compile_options, function, args, &entry->compilation_result); } TF_RETURN_IF_ERROR(entry->compilation_status); CHECK_EQ(entry->executable.get(), nullptr); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index fc5f008f4f52c32d97e680784082d0e7bcb7d8eb..10ad87e38cc4d614e869782329f84351bc3b1f0b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -70,7 +70,7 @@ class XlaCompilationCache : public ResourceBase { OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + const XlaCompiler::CompileOptions& compile_options); // As above, but calls XlaCompiler::CompileSingleOp instead of // XlaCompiler::CompileFunction. @@ -80,7 +80,7 @@ class XlaCompilationCache : public ResourceBase { const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + const XlaCompiler::CompileOptions& compile_options); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } @@ -96,7 +96,7 @@ class XlaCompilationCache : public ResourceBase { OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options, + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op); // Takes `result` which has been compiled from a Tensorflow subgraph to a diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index d288d37bc75380168a31937024dd41bdbe7dce9d..b98c0cb028ff069278dceda21f4588c0da9086e5 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -33,6 +34,7 @@ std::map GetVariables(OpKernelContext* ctx) { OptionalTensor& optional = variables[i]; optional.name = handle.name(); if (LookupResource(ctx, handle, &variable).ok()) { + core::ScopedUnref scoped_unref(variable); tf_shared_lock lock(*variable->mu()); optional.present = true; optional.value = *variable->tensor(); @@ -57,7 +59,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, /*allocate_xla_tensors=*/true, /*use_multiple_streams=*/metadata.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, result, variables); + launch_context.PopulateInputs(ctx, result, variables, + /*missing_ctx_input_prefix=*/0); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -71,13 +74,15 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, run_options.set_stream(stream); run_options.set_allocator(client->backend().memory_allocator()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - run_options.set_rng_seed(ctx->step_id()); + run_options.set_rng_seed(GetXLARandomSeed()); xla::StatusOr run_result = executable->Run(launch_context.arguments(), run_options); TF_RETURN_IF_ERROR(run_result.status()); - launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); + TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( + ctx, result, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); return Status::OK(); } @@ -175,7 +180,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, &compile_options); + result, executable, compile_options); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 7e159e3171113b0d53f03bb676ac9c21db7fe77a..003c1d8081a3313fd042cdcaea14508ed1048da3 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -16,7 +16,7 @@ limitations under the License. // Registers the XLA_CPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "Host" (CPU) backend. -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" @@ -65,10 +65,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kAllXlaCpuTypes = { + {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 4ddeaebd3e42e96d46857a278451d8c97e49a725..0824c4644e3e5d8e1390b99f12de824bfcdfec24 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" @@ -26,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -100,7 +102,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - xla::MakeUnique(); + absl::make_unique(); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; @@ -146,10 +148,9 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { } const DeviceAttributes attrs = Device::BuildDeviceAttributes( - strings::StrCat(name_prefix, "/device:", device_name, ":", - device_ordinal), + absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal), DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), - strings::StrCat("device: ", device_name, " device")); + absl::StrCat("device: ", device_name, " device")); device->reset( new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), @@ -183,14 +184,13 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return device_type_; } -/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, - const Metadata** metadata) { +/*static*/ Status XlaDevice::GetMetadataFromDevice( + DeviceBase* device, const XlaDevice::Metadata** metadata) { *metadata = nullptr; - XlaDevice* xla_device = - dynamic_cast(ctx->device()->UnderlyingDevice()); + XlaDevice* xla_device = dynamic_cast(device->UnderlyingDevice()); if (xla_device == nullptr) { return errors::Internal( - "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(), + "Cannot get XLA metadata from non-XLA device \"", device->name(), "\". GetMetadata must only be called on an XLA device. Either an " "internal bug has been triggered, or an XLA-specific op has been " "placed on the wrong device."); @@ -199,6 +199,16 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } +/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, + const Metadata** metadata) { + return GetMetadataFromDevice(ctx->device(), metadata); +} + +/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx, + const Metadata** metadata) { + return GetMetadataFromDevice(ctx->device(), metadata); +} + XlaDevice::XlaDevice( const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, @@ -216,6 +226,8 @@ XlaDevice::XlaDevice( transfer_as_literal_(transfer_as_literal), shape_representation_fn_(shape_representation_fn) { VLOG(1) << "Created XLA device " << jit_device_name << " " << this; + thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device", + /*num_threads=*/1)); } XlaDevice::~XlaDevice() { @@ -262,10 +274,12 @@ Status XlaDevice::EnsureDeviceContextOk() { Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, const string& name, - xla::StreamPool::Ptr* stream, + std::shared_ptr* stream, bool* stream_was_changed) { if (!(*stream) || !(*stream)->ok()) { - TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_)); + xla::StreamPool::Ptr ptr; + TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_)); + *stream = std::shared_ptr(std::move(ptr)); VLOG(1) << "XlaDevice " << this << " new " << name << " " << (*stream)->DebugStreamPointers(); *stream_was_changed = true; @@ -281,8 +295,8 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_, &need_new_device_context)); - se::Stream* host_to_device_stream = stream_.get(); - se::Stream* device_to_host_stream = stream_.get(); + std::shared_ptr host_to_device_stream = stream_; + std::shared_ptr device_to_host_stream = stream_; if (use_multiple_streams_) { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", &host_to_device_stream_, @@ -290,8 +304,8 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", &device_to_host_stream_, &need_new_device_context)); - host_to_device_stream = host_to_device_stream_.get(); - device_to_host_stream = device_to_host_stream_.get(); + host_to_device_stream = host_to_device_stream_; + device_to_host_stream = device_to_host_stream_; } if (!need_new_device_context) { @@ -304,9 +318,13 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { if (device_context_) { device_context_->Unref(); } + // The XlaDeviceContext keeps a reference count to the streams, and the + // XlaDeviceContext remains live for the duration of a Executor run. This + // ensures that the streams remain live for the duration of a run, even if + // an error is encountered and the streams are replaced with new ones. device_context_ = new XlaDeviceContext( - stream_.get(), host_to_device_stream, device_to_host_stream, client(), - transfer_as_literal_, shape_representation_fn_); + stream_, host_to_device_stream, device_to_host_stream, client(), + transfer_as_literal_, shape_representation_fn_, thread_pool_.get()); VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext " << device_context_; @@ -318,7 +336,7 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { // to those methods; see the bug for details. Our only saving grace at the // moment is that this race doesn't seem to occur in practice. if (use_gpu_device_info_) { - auto gpu_device_info = MakeUnique(); + auto gpu_device_info = absl::make_unique(); gpu_device_info->stream = stream_.get(); gpu_device_info->default_context = device_context_; set_tensorflow_gpu_device_info(gpu_device_info.get()); @@ -355,10 +373,6 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); - // When Xprof profiling is off (which is the default), constructing the - // activity is simple enough that its overhead is negligible. - tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); op_kernel->Compute(context); } @@ -371,6 +385,22 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, op_kernel->ComputeAsync(context, done); } +Status XlaDevice::Sync() { + VLOG(1) << "XlaDevice::Sync"; + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) return Status::OK(); + + if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { + return errors::Internal("XlaDevice::Sync() failed."); + } + VLOG(1) << "XlaDevice::Sync completed"; + return Status::OK(); +} + Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { @@ -404,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, return status; } +void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) { + mutex_lock lock(mu_); + sync_on_completion_ = sync_on_completion; +} + +bool XlaDevice::RequiresSyncOnCompletion() const { + mutex_lock lock(mu_); + return sync_on_completion_; +} + XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { // Any op assigned to the device that isn't rewritten by the graph rewriter diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d8906419b0c406026bb7e10007b2f0a2b4832d01..0f06b3fc80b7c844dae5643127bdabba8a53b35e 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/allocator.h" @@ -89,6 +88,10 @@ class XlaDevice : public LocalDevice { // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata); + // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. + static Status GetMetadata(OpKernelConstruction* ctx, + const Metadata** metadata); + // Factory function. 'platform_name' is the name of the XLA platform. // 'device_name' is the name of the Tensorflow device to create. // 'jit_device_name' is the name of the corresponding JIT device. @@ -124,7 +127,7 @@ class XlaDevice : public LocalDevice { void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; - Status Sync() override { return Status::OK(); } + Status Sync() override; Status FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) override @@ -148,18 +151,27 @@ class XlaDevice : public LocalDevice { // information for GPU and TPU devices. Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); + // Instructs this XlaDevice to return 'sync_on_completion' for + // RequiresSyncOnCompletion(). + void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); + + bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + private: xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, - xla::StreamPool::Ptr* stream, + std::shared_ptr* stream, bool* stream_was_changed) EXCLUSIVE_LOCKS_REQUIRED(mu_); xla::StatusOr GetDeviceContextLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_); - mutex mu_; + static Status GetMetadataFromDevice(DeviceBase* device, + const XlaDevice::Metadata** metadata); + + mutable mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. @@ -174,17 +186,17 @@ class XlaDevice : public LocalDevice { // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. - xla::StreamPool::Ptr stream_ GUARDED_BY(mu_); + std::shared_ptr stream_ GUARDED_BY(mu_); // If false, only stream_ is valid and all computation and transfers use // stream_. If true, computation is performed by stream_ and transfers are // performed by host_to_device/device_to_host_stream. const bool use_multiple_streams_; // If use_multiple_streams_, host to device transfers are performed using this // stream. - xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_); + std::shared_ptr host_to_device_stream_ GUARDED_BY(mu_); // If use_multiple_streams_, device to host transfers are performed using this // stream. - xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_); + std::shared_ptr device_to_host_stream_ GUARDED_BY(mu_); // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. const bool transfer_as_literal_; @@ -198,6 +210,13 @@ class XlaDevice : public LocalDevice { // Holds extra information for GPU and TPU devices, e.g. the device context. bool use_gpu_device_info_ GUARDED_BY(mu_) = false; std::unique_ptr gpu_device_info_ GUARDED_BY(mu_); + + // Thread pool used for running closures + std::unique_ptr thread_pool_; + + // True if the device requires XlaDevice::Sync to be called on completion + // regardless of status. + bool sync_on_completion_ GUARDED_BY(mu_) = false; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 8cf198239c84c3720585f53ebc95876ce4396793..af83c792e5e11d8596c521c6a3aed332a1f42e5b 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" +#include + +#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -48,17 +51,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } XlaTransferManager::XlaTransferManager( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : stream_(compute_stream), - host_to_device_stream_(host_to_device_stream), - device_to_host_stream_(device_to_host_stream), + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : stream_(std::move(compute_stream)), + host_to_device_stream_(std::move(host_to_device_stream)), + device_to_host_stream_(std::move(device_to_host_stream)), client_(client), transfer_manager_(client->backend().transfer_manager()), transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(std::move(shape_representation_fn)) { + shape_representation_fn_(std::move(shape_representation_fn)), + thread_pool_(thread_pool) { CHECK(host_to_device_stream_ != nullptr); CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); @@ -85,50 +91,44 @@ Status XlaTransferManager::TransferLiteralToDevice( const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " << shaped_buffer.ToString(); - if (UseMultipleStreams()) { + if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow( + stream_->parent(), shaped_buffer)) { // Initially wait for the compute stream so that memory allocations are // synchronized. - host_to_device_stream_->ThenWaitFor(stream_); + host_to_device_stream_->ThenWaitFor(stream_.get()); } TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( - host_to_device_stream_, *literal, shaped_buffer)); + host_to_device_stream_.get(), *literal, shaped_buffer)); if (UseMultipleStreams()) { - se::Event event(stream_->parent()); - TF_RET_CHECK(event.Init()) << "Event failed to initialize!"; - host_to_device_stream_->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event)); + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; + host_to_device_stream_->ThenRecordEvent(event.get()); + xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event)); } // Unref the host tensor, and capture the literal shared_ptr too so it goes // out of scope when the lambda completes. host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + return Status::OK(); } void XlaTransferManager::TransferLiteralFromDevice( Tensor* host_tensor, const Tensor& device_tensor, const StatusCallback& done) const { + xla::MutableBorrowingLiteral literal; + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal)); + const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_, shaped_buffer, - [=, &shaped_buffer]( - xla::StatusOr > literal_or) { + device_to_host_stream_.get(), shaped_buffer, literal, + [=, &shaped_buffer](xla::Status status) { ref.Unref(); done([&]() -> Status { - TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString() - << " " << shaped_buffer.ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. - Status status; - if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - status = errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } + VLOG(1) << "Transfer from device as literal: " + << shaped_buffer.ToString(); return status; }()); }); @@ -184,12 +184,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, return; } status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); - if (status.ok()) { - xla_tensor->set_host_tensor(*cpu_tensor); - host_to_device_stream_->ThenDoHostCallback( - [done]() { done(Status::OK()); }); - return; - } } else { se::DeviceMemoryBase dev_dst_ptr = XlaTensor::DeviceMemoryFromTensor(*device_tensor); @@ -199,16 +193,17 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, if (!block_status.ok()) { status = xla::InternalError( "Failed to complete data transfer on stream %p: %s", - host_to_device_stream_, block_status.error_message().c_str()); + host_to_device_stream_.get(), block_status.error_message().c_str()); } } - xla_tensor->set_host_tensor(*cpu_tensor); - + if (status.ok()) { + xla_tensor->set_host_tensor(*cpu_tensor); + } done(status); } void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { @@ -232,9 +227,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); if (se::Event* event = - xla_tensor->GetDefinitionEvent(device_to_host_stream_)) { + xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) { device_to_host_stream_->ThenWaitFor(event); - xla_tensor->SetDefinedOn(device_to_host_stream_); + xla_tensor->SetDefinedOn(device_to_host_stream_.get()); } Status status; @@ -247,7 +242,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Status block_status = device_to_host_stream_->BlockHostUntilDone(); if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, + "Failed to complete data transfer on stream %p: %s", stream_.get(), block_status.error_message().c_str()); } } @@ -285,14 +280,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, if (stream_ != device_to_device_stream) { // Initially wait for the compute stream so that memory allocations are // synchronized. - device_to_device_stream->ThenWaitFor(stream_); + device_to_device_stream->ThenWaitFor(stream_.get()); } } if (se::Event* event = - xla_src->GetDefinitionEvent(device_to_device_stream)) { + xla_src->GetDefinitionEvent(device_to_device_stream.get())) { device_to_device_stream->ThenWaitFor(event); - xla_src->SetDefinedOn(device_to_device_stream); + xla_src->SetDefinedOn(device_to_device_stream.get()); } auto from_iter = xla_src->shaped_buffer().buffers().begin(); @@ -304,28 +299,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, } if (UseMultipleStreams()) { - se::Event event(stream_->parent()); - CHECK(event.Init()); - device_to_device_stream->ThenRecordEvent(&event); - xla_dst->SetDefinedOn(device_to_device_stream, std::move(event)); + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize"; + device_to_device_stream->ThenRecordEvent(event.get()); + xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event)); } return Status::OK(); }(); if (!status.ok()) { return done(status); } else { - stream_->ThenDoHostCallback([=]() { done(Status::OK()); }); + stream_->ThenDoHostCallback([this, done]() { + // We must not call the done closure directly from DoHostCallback to avoid + // a deadlock. If done() is the callback that ends an Executor's run, the + // Executor may call XlaDevice::Sync() inside the callback. This + // deadlocks, because XlaDevice::Sync() waits for all stream activity to + // complete. + thread_pool_->Schedule([done]() { done(Status::OK()); }); + }); } } XlaDeviceContext::XlaDeviceContext( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : manager_(compute_stream, host_to_device_stream, device_to_host_stream, - client, transfer_as_literal, - std::move(shape_representation_fn)) {} + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : manager_(std::move(compute_stream), std::move(host_to_device_stream), + std::move(device_to_host_stream), client, transfer_as_literal, + std::move(shape_representation_fn), thread_pool) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, @@ -335,7 +339,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 912f8d779e72f44821bc4fb25efa30bd35d01412..df824212948ac96a5df5228cecd9a8c864bbec9a 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -47,21 +47,23 @@ class XlaDeviceAllocator : public Allocator { class XlaTransferManager { public: explicit XlaTransferManager( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); - se::Stream* stream() const { return stream_; } + se::Stream* stream() const { return stream_.get(); } private: Status TransferLiteralToDevice(const Tensor& host_tensor, @@ -73,13 +75,13 @@ class XlaTransferManager { // The main compute stream of the device, used to synchronize the transfer // streams if they are set. - se::Stream* stream_; + std::shared_ptr stream_; // The stream to use for transferring data from host to device. Can be // idential to stream_, but must not be nullptr. - se::Stream* host_to_device_stream_; + std::shared_ptr host_to_device_stream_; // The stream to use for transferring data from device to host. Can be // idential to stream_, but must not be nullptr. - se::Stream* device_to_host_stream_; + std::shared_ptr device_to_host_stream_; // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. @@ -87,6 +89,9 @@ class XlaTransferManager { // True if we must use XLA's TransferManager for correct device transfers. const bool transfer_as_literal_; XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // Thread pool used for running closures + thread::ThreadPool* thread_pool_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -95,16 +100,18 @@ class XlaTransferManager { class XlaDeviceContext : public DeviceContext { public: explicit XlaDeviceContext( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 6adda327f186a607b4e7371bf4c5071dd86582da..6967ad1f03fb5dd962d5b41f0c7ab1dfa42fab94 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,7 +23,11 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/data/generator_dataset_op.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/prefetch_dataset_op.h" #include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/function_ops.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" @@ -61,6 +65,16 @@ class XlaAssignVariableOp : public AsyncOpKernel { .HostMemory("resources"), \ KERNEL); +#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \ + .Device(DEVICE) \ + .HostMemory("constants") \ + .HostMemory("resources"), \ + KERNEL); + +#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); + #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \ REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \ @@ -85,9 +99,15 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ ResourceHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \ + ResourceHandlesOp); \ REGISTER_KERNEL_BUILDER( \ Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ ReadVariableOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"), \ + ReadVariablesOp); \ REGISTER_KERNEL_BUILDER( \ Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \ DestroyResourceOp); \ @@ -166,7 +186,71 @@ class XlaAssignVariableOp : public AsyncOpKernel { QueueIsClosedOp); \ \ REGISTER_KERNEL_BUILDER( \ - Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T", TYPES) \ + .HostMemory("input"), \ + RetvalOp); \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \ + data::GeneratorDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \ + .Device(DEVICE) \ + .HostMemory("buffer_size") \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + data::PrefetchDatasetOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ + data::IteratorHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \ + data::MakeIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ + data::AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ + data::IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ + data::IteratorGetNextSyncOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + data::IteratorToStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + data::IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); // TODO(phawkins): currently we do not register the QueueEnqueueMany, // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 4b499b161371ecece14447b29fbf809b6e8857db..bc0db558d8d0b7c666efcfac5c4926144b830380 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -41,8 +42,8 @@ static bool IsShapeConsumerOp(const Node& node) { } // Returns true if the op can be decomposed into XLA ops for which -// there are fusable elemental implementations. -bool IsXlaFusable(const NodeDef& node) { +// there are fusible elemental implementations. +static bool IsXlaFusible(const NodeDef& node) { static const std::unordered_set* elementwise_ops = new std::unordered_set( {// tf2xla/kernels/aggregate_ops.cc @@ -176,9 +177,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); if (device_type.type_string().find("XLA") != string::npos) continue; - // Assume all fusable ops are registered. + // Assume all fusible ops are registered. // TODO(hpucha): Check for registration if possible. - if (!IsXlaFusable(node->def())) { + if (!IsXlaFusible(node->def())) { continue; } @@ -208,6 +209,8 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, GraphCycles cycles; TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( + &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles)); // TODO(hpucha): Make clustering more robust. There are two known issues that // we need to mitigate: (a) Non-resource variables can cause deadlocks @@ -324,7 +327,7 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, string& name = cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc index 5736760a878dc857a8558093054d0adc0f727398..68e19c8a135735a79fcabf121e619157fa22b4d8 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -71,7 +73,7 @@ TEST_F(XlaFusionOptimizerTest, Chains) { EXPECT_TRUE(clusters.find("D") == clusters.cend()); } -TEST_F(XlaFusionOptimizerTest, FusableOps) { +TEST_F(XlaFusionOptimizerTest, FusibleOps) { GraphDef graph; { GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); @@ -179,5 +181,28 @@ TEST_F(XlaFusionOptimizerTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } +TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output var_handle = + ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({})); + Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f); + Output begin = ops::Const(root.WithOpName("begin"), 0); + Output end = ops::Const(root.WithOpName("end"), 1); + Output strides = ops::Const(root.WithOpName("strides"), 1); + ops::ResourceStridedSliceAssign assign_1( + root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign); + ops::ResourceStridedSliceAssign assign_2( + root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign); + root.graph()->AddControlEdge(assign_1.operation.node(), + assign_2.operation.node()); + grappler::GrapplerItem item; + root.graph()->ToGraphDef(&item.graph); + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_NE(clusters["assign_1"], clusters["assign_2"]); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index ef4466f0056ea98adc1ae6774105466af0d14293..60979556a3245f4a9984cde889835ce31154fe18 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,7 +16,7 @@ limitations under the License. // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -74,11 +74,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, - DT_BFLOAT16}}; +constexpr std::array kAllXlaGpuTypes = { + {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 45745596749207189c60ee1e3dcf19b6ecb7eb5b..19e681af0c940023de2ce82b3b337babe2f3dd5a 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -15,7 +15,7 @@ limitations under the License. // Registers the XLA_INTERPRETER device which exposes the XLA Interpreter. -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -72,6 +72,10 @@ static bool OpFilter(KernelDef* kdef) { return true; } REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp, kExecAllTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp, + kExecAllTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes); REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 6134b8c6946429918a5ca37188cbff13a6cd1c79..4f6fc4e068e3ba125ddbca264c1affa1f09f5896 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_launch_util.h" +#include + +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -39,13 +42,14 @@ using xla::ShapedBuffer; } // anonymous namespace std::map SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector& variables) { + OpKernelContext* ctx, absl::Span variables) { std::map snapshot; for (int i : variables) { Var* variable = nullptr; ResourceHandle handle = HandleFromInput(ctx, i); OptionalTensor& tensor = snapshot[i]; if (LookupResource(ctx, handle, &variable).ok()) { + core::ScopedUnref scoped_unref(variable); tf_shared_lock lock(*variable->mu()); tensor.name = handle.name(); tensor.present = true; @@ -130,7 +134,8 @@ XlaComputationLaunchContext::XlaComputationLaunchContext( void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map& variables) { + const std::map& variables, + int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. @@ -142,12 +147,13 @@ void XlaComputationLaunchContext::PopulateInputs( const Tensor* t; for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { int arg_num = kernel->input_mapping[i]; + DCHECK_GE(arg_num, missing_ctx_input_prefix); const xla::Shape& shape = kernel->xla_input_shapes[i]; if (variables.count(arg_num)) { t = &(variables.at(arg_num).value); CHECK(t); } else { - t = &(ctx->input(arg_num)); + t = &(ctx->input(arg_num - missing_ctx_input_prefix)); } if (use_multiple_streams_) { @@ -173,7 +179,7 @@ void XlaComputationLaunchContext::PopulateInputs( << " not the same as on-host shape " << xla::ShapeUtil::HumanStringWithLayout(shape); se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); - arg_buffers_[i] = xla::MakeUnique( + arg_buffers_[i] = absl::make_unique( /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(), client_->default_device_ordinal()); arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); @@ -182,9 +188,9 @@ void XlaComputationLaunchContext::PopulateInputs( } } -void XlaComputationLaunchContext::PopulateOutputs( +Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - ScopedShapedBuffer output) { + ScopedShapedBuffer output, int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -211,6 +217,15 @@ void XlaComputationLaunchContext::PopulateOutputs( output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator()); } + std::shared_ptr definition_event; + if (use_multiple_streams_) { + definition_event = std::make_shared(stream->parent()); + if (!definition_event->Init()) { + return errors::Internal("Failed to initialize tensor definition event."); + } + stream->ThenRecordEvent(definition_event.get()); + } + // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { @@ -228,12 +243,13 @@ void XlaComputationLaunchContext::PopulateOutputs( // reallocate the device buffer later. VLOG(1) << "Constant output tensor on device"; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); Device* device = dynamic_cast(ctx->device()); - OP_REQUIRES(ctx, device != nullptr, - errors::Internal("DeviceBase was not a Device.")); + if (device == nullptr) { + return errors::Internal("DeviceBase was not a Device."); + } ctx->op_device_context()->CopyCPUTensorToDevice( &const_tensor, device, output_tensor, [&](Status status) { TF_CHECK_OK(status); }); @@ -258,34 +274,38 @@ void XlaComputationLaunchContext::PopulateOutputs( } } else { const TensorShape& shape = kernel->outputs[i].shape; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); - - se::DeviceMemoryBase buffer = output.buffer({output_num}); - if (allocate_xla_tensors_) { - Tensor* output_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); - XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - if (xla_tensor) { - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); - if (use_multiple_streams_) { - se::Event event(stream->parent()); - CHECK(event.Init()); - stream->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(stream, std::move(event)); + const DataType& type = kernel->outputs[i].type; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " + << DataTypeString(type); + if (type == DT_RESOURCE) { + TF_RET_CHECK(kernel->outputs[i].input_index >= 0) + << "Invalid input for outputs " << i; + ctx->set_output(i, ctx->input(kernel->outputs[i].input_index)); + } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + if (allocate_xla_tensors_) { + Tensor* output_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); + if (xla_tensor) { + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (use_multiple_streams_) { + xla_tensor->SetDefinedOn(stream, definition_event); + } + } else { + // xla_tensor wasn't valid, which must mean this is a zero-element + // tensor. + CHECK_EQ(output_tensor->TotalBytes(), 0); } } else { - // xla_tensor wasn't valid, which must mean this is a zero-element - // tensor. - CHECK_EQ(output_tensor->TotalBytes(), 0); + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + ctx->expected_output_dtype(i), shape, buffer, allocator); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); + ctx->set_output(i, output_tensor); } - } else { - Tensor output_tensor = XlaTensorBuffer::MakeTensor( - ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); - ctx->set_output(i, output_tensor); + ++output_num; } - ++output_num; } if (VLOG_IS_ON(3)) { @@ -298,41 +318,40 @@ void XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - OP_REQUIRES(ctx, - write.input_index >= 0 && write.input_index < ctx->num_inputs(), - errors::Internal("Invalid input index for variable write.")); + int actual_input_index = write.input_index - missing_ctx_input_prefix; + if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) { + return errors::Internal("Invalid input index for variable write."); + } se::DeviceMemoryBase buffer = output.buffer({output_num}); Var* variable = nullptr; // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. - OP_REQUIRES_OK(ctx, LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), - &variable, [this, ctx, &write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); + TF_RETURN_IF_ERROR(LookupOrCreateResource( + ctx, HandleFromInput(ctx, actual_input_index), &variable, + [&write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); core::ScopedUnref s(variable); mutex_lock ml(*variable->mu()); - OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, - errors::Internal("Mismatched type in variable write")); + if (variable->tensor()->dtype() != write.type) { + return errors::Internal("Mismatched type in variable write"); + } if (allocate_xla_tensors_) { Tensor output_tensor; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(write.type, write.shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); CHECK(xla_tensor); xla_tensor->set_shaped_buffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); if (use_multiple_streams_) { - se::Event event(stream->parent()); - CHECK(event.Init()); - stream->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(stream, std::move(event)); + xla_tensor->SetDefinedOn(stream, definition_event); } *variable->tensor() = output_tensor; } else { @@ -343,6 +362,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } ++output_num; } + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 1ea3fa4cf29266e8c452385226e56bd0b82622d9..326d70a027564343408df356833c97e131495da0 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { class XlaAllocator; @@ -43,7 +44,7 @@ class XlaAllocator; // resource variable is not initialized, the corresponding OptionalTensor // will have its `present` field set to false. std::map SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector& variables); + OpKernelContext* ctx, absl::Span variables); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: @@ -88,14 +89,24 @@ class XlaComputationLaunchContext { // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. All elements in kernel's + // input_mapping must be greater than or equal to `missing_ctx_input_prefix` + // (in other words, no inputs actually required by the kernel can be missing). void PopulateInputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map& variables); + const std::map& variables, + int missing_ctx_input_prefix); // Given the XLA output in `output`, populate all outputs of `ctx`. - void PopulateOutputs(OpKernelContext* ctx, - const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output); + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. + Status PopulateOutputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + xla::ScopedShapedBuffer output, + int missing_ctx_input_prefix); // Return the argument list. Only valid after PopulateInputs() has been // called. @@ -167,4 +178,4 @@ xla::ScopedShapedBuffer ExtractSubShapedBuffer( } // namespace tensorflow -#endif +#endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index d777dfa5a34fb9615ddcf393ed53be1491cb70af..92ba7de1b7d32fcf693cd12a380d7a1e0d861d71 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { mutex_lock lock(mu_); - if (!definition_event_.has_value()) { + if (!definition_event_) { return nullptr; } @@ -87,10 +87,11 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { return nullptr; } - return &*definition_event_; + return definition_event_.get(); } -void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) { +void XlaTensor::SetDefinedOn(se::Stream* stream, + std::shared_ptr event) { mutex_lock lock(mu_); definition_event_ = std::move(event); streams_defined_on_ = {stream}; diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index f7e401c731163200c518074f2caa6907efb1f684..d95da63405889dfd0c279b17789a2195072c7277 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ #define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ +#include + +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/core/framework/allocator.h" @@ -68,7 +71,7 @@ class XlaTensor { // Mutates the XlaTensor to set the ShapedBuffer. void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { shaped_buffer_ = - xla::MakeUnique(std::move(shaped_buffer)); + absl::make_unique(std::move(shaped_buffer)); } // Some tensors on the device may have known values on the host. We use these @@ -94,7 +97,7 @@ class XlaTensor { // Assert that the tensor's content is defined on 'stream' by the time 'event' // triggers. - void SetDefinedOn(se::Stream* stream, se::Event event); + void SetDefinedOn(se::Stream* stream, std::shared_ptr event); // Assert that the tensor's content is defined on 'stream'. This version does // not provide an event, and must be called *after* SetDefinedOn(Stream, @@ -116,13 +119,13 @@ class XlaTensor { // An optional event that is triggered when the tensor's content has been // defined. If this event is nullptr, it is assumed that the tensor's content // is always defined. - gtl::optional definition_event_; + std::shared_ptr definition_event_; // A list of all streams for which the tensor's content is defined for any // newly enqueued command. - gtl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); + absl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); mutex mu_; }; } // namespace tensorflow -#endif +#endif // TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index b7dc5d4c74cb41b5e758e8170a44090bf04e5420..3cf74fa7880c96198f9072ab7488a1cec15c9e5c 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -72,7 +72,7 @@ py_test( tf_xla_py_test( name = "adadelta_test", - size = "medium", + size = "large", srcs = ["adadelta_test.py"], deps = [ ":xla_test", @@ -251,6 +251,7 @@ tf_xla_py_test( tf_xla_py_test( name = "matrix_triangular_solve_op_test", size = "small", + timeout = "moderate", srcs = ["matrix_triangular_solve_op_test.py"], tags = ["optonly"], deps = [ @@ -276,9 +277,10 @@ tf_xla_py_test( ], ) +# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors tf_xla_py_test( name = "concat_ops_test", - size = "medium", + size = "large", srcs = ["concat_ops_test.py"], deps = [ ":xla_test", @@ -387,6 +389,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "reshape_op_test", + size = "small", + srcs = ["reshape_op_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "@absl_py//absl/testing:parameterized", + ], +) + tf_xla_py_test( name = "dynamic_stitch_test", size = "small", @@ -559,6 +574,7 @@ tf_xla_py_test( tf_xla_py_test( name = "matrix_band_part_test", size = "medium", + timeout = "long", srcs = ["matrix_band_part_test.py"], tags = ["optonly"], deps = [ @@ -566,6 +582,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -691,11 +708,7 @@ tf_xla_py_test( size = "small", srcs = ["random_ops_test.py"], disabled_backends = [ - # TODO(b/110300529): RngNormal doesn't return values with the expected variance - "cpu", "cpu_ondemand", - # TODO(b/31361304): enable RNG ops on GPU when parallelized. - "gpu", ], deps = [ ":xla_test", @@ -719,6 +732,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -964,7 +978,7 @@ tf_xla_py_test( name = "gather_test", size = "medium", srcs = ["gather_test.py"], - tags = ["noasan"], # times out, http://b/78599043 + tags = ["optonly"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1091,6 +1105,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/strings", ], ) @@ -1181,3 +1196,32 @@ tf_xla_py_test( "//tensorflow/python:platform_test", ], ) + +tf_xla_py_test( + name = "quantized_ops_test", + size = "small", + srcs = ["quantized_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "xla_ops_test", + size = "medium", + srcs = ["xla_ops_test.py"], + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py index 3e3c09c66e72c4de141b64cea3c4693fabb7b2a2..b7b7fda293b69d6f0cec61d0d234277636a3670d 100644 --- a/tensorflow/compiler/tests/adadelta_test.py +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -33,7 +33,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase): def testBasic(self): num_updates = 4 # number of ADADELTA steps to perform for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): for grad in [0.2, 0.1, 0.01]: for lr in [1.0, 0.5, 0.1]: var0_init = [1.0, 2.0] diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py index dc1625793aa44b96d3b96e175237caf96e7d7e74..69fb3ec2964a09508e612515b9e291fc14121d68 100644 --- a/tensorflow/compiler/tests/adagrad_da_test.py +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -33,7 +33,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithoutRegularizationBasic1(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) @@ -69,7 +69,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAwithoutRegularizationBasic2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) @@ -100,7 +100,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithL1(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) @@ -131,7 +131,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithL1_L2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index d775850a80e9f83f7b2c9f1cf8997dd50e229635..ab69319c59fb07e7ce56c3c287a50a6290effdfd 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -32,7 +32,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -57,7 +57,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -83,7 +83,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 03554d6933aca39b428c6af4be0c78e2c7ccb0c9..058576b3d4b695209952158769162bb24e7ccfce 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -52,7 +53,10 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + # TODO: test fails for float16 due to excessive precision requirements. + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: + continue + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -91,7 +95,10 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + # TODO: test fails for float16 due to excessive precision requirements. + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: + continue + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -130,7 +137,10 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + # TODO: test fails for float16 due to excessive precision requirements. + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: + continue + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py index c4fdbc5974319db9243eb2c323746cbaaea795f6..3ed1d41b7121f44dd7470f61180f7a7055369174 100644 --- a/tensorflow/compiler/tests/adamax_test.py +++ b/tensorflow/compiler/tests/adamax_test.py @@ -49,7 +49,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): def testBasic(self): for i, dtype in enumerate(self.float_types): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 @@ -100,7 +100,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py index 9ec5a964cbb4dd98d2ef2d0b684872292118800f..1bc07ace23ccdc83103abe71ee11b72994c75a6d 100644 --- a/tensorflow/compiler/tests/addsign_test.py +++ b/tensorflow/compiler/tests/addsign_test.py @@ -63,7 +63,7 @@ class AddSignTest(xla_test.XLATestCase): alpha=1.0, beta=0.9): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): # Initialize variables for numpy implementation. m0, m1 = 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index 9d3a889b1f54c813e881bb03b5275f809af1b3c8..68f52e796c283997b71abcdb9c3bd6aa19cb06fc 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -40,7 +40,7 @@ class ArgMinMaxTest(xla_test.XLATestCase): op_input: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. """ - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pinp = array_ops.placeholder( dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") @@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase): def testArgMinMax(self): # Complex numbers do not support argmin/argmax. - minmax_types = set(self.numeric_types) - set(self.complex_types) + minmax_types = self.all_types & {np.int32, np.int64} for dtype in minmax_types: # output_type is a numpy data type that is used to specify the desired # output type of the op as well as to convert the Python number to the # array scalar of the type. - for output_type in self.int_types: + for output_type in minmax_types: self._assertOpOutputMatchesExpected( math_ops.argmax, axis=0, diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 0aafda7fb4d710f154157ee352d6616e5aa8935f..1b39d53dc0908e1fa05f766ca1e601731b26846d 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -36,7 +36,7 @@ class BinaryOpsTest(xla_test.XLATestCase): """Test cases for binary operators.""" def _testBinary(self, op, a, b, expected, equality_test=None): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") @@ -210,7 +210,7 @@ class BinaryOpsTest(xla_test.XLATestCase): equality_test=self.ListsAreClose) def testIntOps(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testBinary( gen_math_ops.truncate_div, np.array([3, 3, -1, -9, -8], dtype=dtype), @@ -287,7 +287,8 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) - if dtype not in self.complex_types: # min/max not supported for complex + # min/max not supported for complex + if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( math_ops.maximum, np.array([1, 2], dtype=dtype), @@ -337,7 +338,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([[70], [14]], dtype=dtype)) # Complex support for squared_difference is incidental, see b/68205550 - if dtype not in self.complex_types: + if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( math_ops.squared_difference, np.array([1, 2], dtype=dtype), @@ -559,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(2), expected=np.array([[5], [2]], dtype=dtype)) + if dtype in [np.float32, np.float64]: + nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1) + divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24) + np_result = np.true_divide(nums, divs) + np_result[:, divs[0] == 0] = 0 + self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result) + if dtype not in self.complex_types: # floordiv unsupported for complex. self._testBinary( gen_math_ops.floor_div, @@ -567,7 +575,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) def testIntDivision(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testDivision(dtype) def testFloatDivision(self): @@ -588,7 +596,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, 1, -1, 0], dtype=dtype)) def testIntRemainder(self): - for dtype in self.int_types: + for dtype in self.signed_int_types - {np.int8}: self._testRemainder(dtype) def testFloatRemainder(self): @@ -1010,7 +1018,38 @@ class BinaryOpsTest(xla_test.XLATestCase): [7, 7, 7, 7, 7, 7]], dtype=dtype)) - def testMirrorPad(self): + def testSymmetricMirrorPad(self): + mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC") + for dtype in self.numeric_types: + self._testBinary( + mirror_pad, + np.array( + [ + [1, 2, 3], # + [4, 5, 6], # + ], + dtype=dtype), + np.array([[ + 2, + 2, + ], [3, 3]], dtype=np.int32), + expected=np.array( + [ + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + ], + dtype=dtype)) + self._testBinary( + mirror_pad, + np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array([[0, 0], [0, 0]], dtype=np.int32), + expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) + + def testReflectMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") for dtype in self.numeric_types: self._testBinary( @@ -1165,6 +1204,16 @@ class BinaryOpsTest(xla_test.XLATestCase): def testTile(self): for dtype in self.numeric_types: + self._testBinary( + array_ops.tile, + np.array([[6], [3], [4]], dtype=dtype), + np.array([2, 0], dtype=np.int32), + expected=np.empty([6, 0], dtype=dtype)) + self._testBinary( + array_ops.tile, + np.array([[6, 3, 4]], dtype=dtype), + np.array([2, 0], dtype=np.int32), + expected=np.empty([2, 0], dtype=dtype)) self._testBinary( array_ops.tile, np.array([[6]], dtype=dtype), @@ -1362,5 +1411,47 @@ class BinaryOpsTest(xla_test.XLATestCase): [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], dtype=dtype)) + def testBroadcastTo(self): + for dtype in self.all_types: + x = np.random.randint(0, high=100, size=[2, 3]) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([2, 3], dtype=np.int32), + expected=x) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([6, 6], dtype=np.int32), + expected=np.tile(x, [3, 2])) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 4, 3], dtype=np.int32), + expected=np.tile(x, [7, 2, 1])) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 0, 3], dtype=np.int32), + expected=np.zeros([7, 0, 3], dtype=dtype)) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 1, 2, 9], dtype=np.int32), + expected=np.tile(x, [7, 1, 1, 3])) + self._testBinary( + array_ops.broadcast_to, + np.zeros([2, 0], dtype=dtype), + np.array([4, 0], dtype=np.int32), + expected=np.zeros([4, 0], dtype=dtype)) + + x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype) + self._testBinary( + array_ops.broadcast_to, + x, + np.array((3, 7, 8, 9), dtype=np.int32), + expected=np.tile(x, (1, 7, 8, 9))) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py index ef4d5f6322b7ae79b051795b5af7e6f7f1e55550..5c24db539bce5df701d8229290ddb4c20997d40a 100644 --- a/tensorflow/compiler/tests/bucketize_op_test.py +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class BucketizationOpTest(xla_test.XLATestCase): def testInt(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.int32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) @@ -38,7 +38,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) def testFloat(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.float32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) @@ -48,7 +48,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]})) def test2DInput(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.float32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) @@ -58,7 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase): {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) def testInvalidBoundariesOrder(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.int32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) @@ -67,7 +67,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5, 0]}) def testBoundariesNotList(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, "Expected list.*"): p = array_ops.placeholder(dtypes.int32) with self.test_scope(): diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 7b114d4f85d3a5cadc6af25b55c5a21f90d2a768..1d3979b21bfd915a641fabe1ef40301b3e5a17b4 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -2,90 +2,103 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) def all_backends(): - b = ["cpu"] + plugins.keys() - if cuda_is_configured(): - return b + ["gpu"] - else: - return b + b = ["cpu"] + plugins.keys() + if cuda_is_configured(): + return b + ["gpu"] + else: + return b -def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, - disabled_backends=None, **kwargs): - """Generates py_test targets, one per XLA backend. +def tf_xla_py_test( + name, + srcs = [], + deps = [], + tags = [], + data = [], + main = None, + disabled_backends = None, + **kwargs): + """Generates py_test targets, one per XLA backend. - This rule generates py_test() targets named name_backend, for each backend - in all_backends(). The rule also generates a test suite with named `name` that - tests all backends for the test. + This rule generates py_test() targets named name_backend, for each backend + in all_backends(). The rule also generates a test suite with named `name` that + tests all backends for the test. - For example, the following rule generates test cases foo_test_cpu, - foo_test_gpu, and a test suite name foo_test that tests both. - tf_xla_py_test( - name="foo_test", - srcs="foo_test.py", - deps=[...], - ) + For example, the following rule generates test cases foo_test_cpu, + foo_test_gpu, and a test suite name foo_test that tests both. + tf_xla_py_test( + name="foo_test", + srcs="foo_test.py", + deps=[...], + ) - Args: - name: Name of the target. - srcs: Sources for the target. - deps: Dependencies of the target. - tags: Tags to apply to the generated targets. - data: Data dependencies of the target. - main: Same as py_test's main attribute. - disabled_backends: A list of backends that should not be tested. Supported - values include "cpu" and "gpu". If not specified, defaults to None. - **kwargs: keyword arguments passed onto the generated py_test() rules. - """ - if disabled_backends == None: - disabled_backends = [] + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + tags: Tags to apply to the generated targets. + data: Data dependencies of the target. + main: Same as py_test's main attribute. + disabled_backends: A list of backends that should not be tested. Supported + values include "cpu" and "gpu". If not specified, defaults to None. + **kwargs: keyword arguments passed onto the generated py_test() rules. + """ + if disabled_backends == None: + disabled_backends = [] - enabled_backends = [b for b in all_backends() if b not in disabled_backends] - test_names = [] - for backend in enabled_backends: - test_name = "{}_{}".format(name, backend) - backend_tags = ["tf_xla_{}".format(backend)] - backend_args = [] - backend_deps = [] - backend_data = [] - if backend == "cpu": - backend_args += [ - "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" - ] - elif backend == "gpu": - backend_args += [ - "--test_device=XLA_GPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16" - ] - backend_tags += ["requires-gpu-sm35"] - elif backend in plugins: - backend_args += ["--test_device=" + plugins[backend]["device"], - "--types=" + plugins[backend]["types"]] - backend_tags += plugins[backend]["tags"] - backend_args += plugins[backend]["args"] - backend_deps += plugins[backend]["deps"] - backend_data += plugins[backend]["data"] - else: - fail("Unknown backend {}".format(backend)) + enabled_backends = [b for b in all_backends() if b not in disabled_backends] + test_names = [] + for backend in enabled_backends: + test_name = "{}_{}".format(name, backend) + backend_tags = ["tf_xla_{}".format(backend)] + backend_args = [] + backend_deps = [] + backend_data = [] + if backend == "cpu": + backend_args += [ + "--test_device=XLA_CPU", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + ] + elif backend == "gpu": + backend_args += [ + "--test_device=XLA_GPU", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16", + ] + backend_tags += tf_cuda_tests_tags() + elif backend in plugins: + backend_args += [ + "--test_device=" + plugins[backend]["device"], + "--types=" + plugins[backend]["types"], + ] + backend_tags += plugins[backend]["tags"] + backend_args += plugins[backend]["args"] + backend_deps += plugins[backend]["deps"] + backend_data += plugins[backend]["data"] + else: + fail("Unknown backend {}".format(backend)) - native.py_test( - name=test_name, - srcs=srcs, - srcs_version="PY2AND3", - args=backend_args, - main="{}.py".format(name) if main == None else main, - data=data + backend_data, - deps=deps + backend_deps, - tags=tags + backend_tags, - **kwargs - ) - test_names.append(test_name) - native.test_suite(name=name, tests=test_names) + native.py_test( + name = test_name, + srcs = srcs, + srcs_version = "PY2AND3", + args = backend_args, + main = "{}.py".format(name) if main == None else main, + data = data + backend_data, + deps = deps + backend_deps, + tags = tags + backend_tags, + **kwargs + ) + test_names.append(test_name) + native.test_suite(name = name, tests = test_names) -def generate_backend_suites(backends=[]): - """Generates per-backend test_suites that run all tests for a backend.""" - if not backends: - backends = all_backends() - for backend in backends: - native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend]) +def generate_backend_suites(backends = []): + """Generates per-backend test_suites that run all tests for a backend.""" + if not backends: + backends = all_backends() + for backend in backends: + native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend]) diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index a4e7f75081dfd07fd4b5c94c33908aab8e7d8aa9..a57d1dc81ea2c9c188b0a3005904738aa8156bf3 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -56,7 +56,7 @@ class CategoricalTest(xla_test.XLATestCase): Returns: Frequencies from sampled classes; shape [batch_size, num_classes]. """ - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): random_seed.set_random_seed(1618) op = random_ops.multinomial(logits, num_samples, output_dtype=dtypes.int32) @@ -79,7 +79,7 @@ class CategoricalTest(xla_test.XLATestCase): def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = rng(dtype, output_dtype) @@ -107,7 +107,7 @@ class CategoricalTest(xla_test.XLATestCase): def testCategoricalIsInRange(self): for dtype in self.float_types: for output_dtype in self.output_dtypes(): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = random_ops.multinomial( array_ops.ones(shape=[1, 20], dtype=dtype), 1000, diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index ed532db0ee5553a275192e6cc3ebf394075fa0e1..d1896a50f7037f2972cba8a4fa16cc1e2cd4fe3e 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -54,7 +54,7 @@ class CholeskyOpTest(xla_test.XLATestCase): def _verifyCholesky(self, x, atol=1e-6): # Verify that LL^T == x. - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder( dtypes.as_dtype(x.dtype), shape=x.shape) with self.test_scope(): diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index e42ebf8f9e01dab13cde15979ffc42b7c0fbc57b..88bd58b2da6b2892f898ad10f3467d8ce39d6388 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -38,7 +38,7 @@ class ClusteringTest(xla_test.XLATestCase): val1 = np.array([4, 3, 2, 1], dtype=np.float32) val2 = np.array([5, 6, 7, 8], dtype=np.float32) expected = val1 + val2 - with self.test_session(): + with self.cached_session(): with self.test_scope(): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") @@ -50,7 +50,7 @@ class ClusteringTest(xla_test.XLATestCase): val1 = np.array([4, 3, 2, 1]).astype(np.float32) val2 = np.array([5, 6, 7, 8]).astype(np.float32) expected = val1 + val2 - with self.test_session(): + with self.cached_session(): with ops.device(CPU_DEVICE): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") @@ -68,7 +68,7 @@ class ClusteringTest(xla_test.XLATestCase): # where x and z are placed on the CPU and y and w are placed on the XLA # device. If y and w are clustered for compilation, then the graph will # deadlock since the clustered graph will contain a self-loop. - with self.test_session() as sess: + with self.cached_session() as sess: with ops.device(CPU_DEVICE): x = array_ops.placeholder(dtypes.float32, [2]) with self.test_scope(): @@ -81,7 +81,7 @@ class ClusteringTest(xla_test.XLATestCase): self.assertAllClose(result, [12., 2.], rtol=1e-3) def testHostMemory(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.int32) with self.test_scope(): y = x + 1 diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index d9ad4281477e87f79f2ecb52989ae86a5030d0cc..2d225ad226cac368042b95eae8fc29e6fd8e82e0 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest class ConcatTest(xla_test.XLATestCase): def testHStack(self): - with self.test_session(): + with self.cached_session(): p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) with self.test_scope(): @@ -49,7 +49,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(result[4:, :], params[p2]) def testVStack(self): - with self.test_session(): + with self.cached_session(): p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) with self.test_scope(): @@ -65,7 +65,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(result[:, 4:], params[p2]) def testInt32(self): - with self.test_session(): + with self.cached_session(): p1 = np.random.rand(2, 3).astype("i") p2 = np.random.rand(2, 3).astype("i") x1 = constant_op.constant(p1) @@ -88,7 +88,7 @@ class ConcatTest(xla_test.XLATestCase): dtype_feed = dtypes.float32 else: dtype_feed = dtype - with self.test_session(): + with self.cached_session(): p = [] for i in np.arange(num_tensors): input_shape = shape @@ -130,7 +130,7 @@ class ConcatTest(xla_test.XLATestCase): self._testRandom(dtypes.int32) def _testGradientsSimple(self): - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -157,7 +157,7 @@ class ConcatTest(xla_test.XLATestCase): self._testGradientsSimple() def _testGradientsFirstDim(self): - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -185,7 +185,7 @@ class ConcatTest(xla_test.XLATestCase): self._testGradientsFirstDim() def _testGradientsLastDim(self): - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -220,7 +220,7 @@ class ConcatTest(xla_test.XLATestCase): # Random dim to concat on concat_dim = np.random.randint(5) concat_dim_sizes = np.random.randint(1, 5, size=num_tensors) - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase): def DISABLED_testZeroSize(self): # Verify that concat doesn't crash and burn for zero size inputs np.random.seed(7) - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): for shape0 in (), (2,): axis = len(shape0) @@ -276,14 +276,14 @@ class ConcatTest(xla_test.XLATestCase): def testConcatTuple(self): c1 = np.random.rand(4, 4).astype(np.float32) c2 = np.random.rand(4, 4).astype(np.float32) - with self.test_session(): + with self.cached_session(): with self.test_scope(): concat_list_t = array_ops.concat([c1, c2], 0) concat_tuple_t = array_ops.concat((c1, c2), 0) self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) def testConcatNoScalars(self): - with self.test_session(): + with self.cached_session(): with self.test_scope(): scalar = constant_op.constant(7) dim = array_ops.placeholder(dtypes.int32) @@ -291,11 +291,46 @@ class ConcatTest(xla_test.XLATestCase): ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"): array_ops.concat([scalar, scalar, scalar], dim) + # The purpose of this is to ensure that XLA on GPU will not run out of memory + # with too many arguments. + def testConcatLargeNumberOfTensors(self): + with self.cached_session(): + with self.test_scope(): + for concat_dim in range(2): + params = {} + p = [] + shape = np.array([7, 13]) + num_tensors = 1001 + for i in np.arange(num_tensors): + input_shape = shape + placeholder = array_ops.placeholder( + dtypes.float32, shape=input_shape) + p.append(placeholder) + params[placeholder] = np.random.rand(*input_shape).astype( + np.float32) + + concat_inputs = p + c = array_ops.concat(concat_inputs, concat_dim) + result = c.eval(feed_dict=params) + + self.assertEqual(result.shape, c.get_shape()) + cur_offset = 0 + + for i in np.arange(num_tensors): + # The index into the result is the ':' along all dimensions + # except the concat_dim. slice(0, size) is used for ':', and + # a list of slices is used to index into result. + index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)] + index[concat_dim] = slice( + cur_offset, cur_offset + params[p[i]].shape[concat_dim]) + cur_offset += params[p[i]].shape[concat_dim] + self.assertAllEqual(result[index], params[p[i]]) + class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) @@ -309,7 +344,7 @@ class ConcatOffsetTest(xla_test.XLATestCase): class PackTest(xla_test.XLATestCase): def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) @@ -319,7 +354,7 @@ class PackTest(xla_test.XLATestCase): self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) def testScalars(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): s0 = constant_op.constant(2, dtypes.int32) s1 = constant_op.constant(3, dtypes.int32) @@ -329,7 +364,7 @@ class PackTest(xla_test.XLATestCase): self.assertAllEqual(ans, [2, 3, 5]) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): s0 = constant_op.constant([[]], dtypes.int32) s1 = constant_op.constant([[]], dtypes.int32) diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index f9db103f6d0f9ea0e393a0971593552ec5c14079..af00ff287d43a8542b5a3d14eedc00c3d7aef1b7 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -87,7 +87,7 @@ class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) with self.test_scope(): @@ -288,7 +288,7 @@ class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): @@ -586,7 +586,7 @@ class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 31ee41f04f27d387415e9fa2c4fa70b33cab7b04..33fd983b5485e503c2fcc96db2dfdecfc41e309f 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): def testGradient(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): for padding in ["SAME", "VALID"]: for stride in [1, 2]: np.random.seed(1) @@ -69,7 +69,7 @@ class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): class Conv3DTransposeTest(xla_test.XLATestCase): def testConv3DTransposeSingleStride(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): strides = [1, 1, 1, 1, 1] # Input, output: [batch, depth, height, width, channel] @@ -119,7 +119,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): self.assertAllClose(target, value[n, d, h, w, k]) def testConv3DTransposeSame(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): strides = [1, 2, 2, 2, 1] # Input, output: [batch, depth, height, width, depth] @@ -157,7 +157,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): self.assertAllClose(target, value[n, d, h, w, k]) def testConv3DTransposeValid(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): strides = [1, 2, 2, 2, 1] # Input, output: [batch, depth, height, width, depth] @@ -217,7 +217,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): np.random.seed(1) # Make it reproducible. x_val = np.random.random_sample(x_shape).astype(np.float64) f_val = np.random.random_sample(f_shape).astype(np.float64) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): x = constant_op.constant(x_val, name="x", dtype=dtypes.float32) f = constant_op.constant(f_val, name="f", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 865f60ccab46ec6829e49409508303052944e13b..9390870e07d6b5bd90dbc5c04bac0946595dcf7f 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -45,20 +45,25 @@ def InLabels(labels, substr): return any([substr in x for x in labels]) -def XlaLaunchOpCount(labels): - """Count how many XlaLaunch labels are present.""" - return sum("XlaLaunch(" in x for x in labels) +class DenseLayerTest(test.TestCase): + def countXlaOps(self, labels): + """Count how many XlaCompile/XlaRun labels are present.""" + xla_compile_count = sum("XlaCompile(" in x for x in labels) + xla_run_count = sum("XlaRun(" in x for x in labels) + self.assertEqual(xla_compile_count, xla_run_count) + return xla_run_count -class DenseLayerTest(test.TestCase): def testDenseLayerAutoJit(self): """Tests dense layer compilation in auto-jit mode. - Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. + Dense layer should be compiled into a single XlaCompile/XlaRun op pair in + auto-jit mode. """ - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) config = config_pb2.ConfigProto() config.graph_options.optimizer_options.global_jit_level = ( config_pb2.OptimizerOptions.ON_1) @@ -76,17 +81,17 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(1, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertEqual(1, self.countXlaOps(labels)) + self.assertFalse(InLabels(labels, "MatMult")) def testDenseLayerJitScopeDefinedShape(self): """Tests that the dense layer node is properly compiled in jit scope. Dense layer with static shape input tensor should be compiled into a single - XlaLaunch op by XLA. + XlaCompile/XlaRun op pair by XLA. """ - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32) with jit_scope(): y = layers.dense(x, 3) @@ -100,7 +105,7 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertEqual(1, self.countXlaOps(labels)) # No need to check whether ListDiff is compiled or not because ListDiff op # is not used when input tensor shape is fully defined. @@ -110,10 +115,11 @@ class DenseLayerTest(test.TestCase): Dense layer uses shape op to get shape of input tensor if its shape is not fully defined. XLA does not cluster shape op with other operators. But in experimental_jit_scope, XLA is forced to compile shape op into its own - cluster, causing dense layer to be split into TWO XlaLaunch ops. + cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op + pairs. """ - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) with jit_scope(): y = layers.dense(x, 3) @@ -127,8 +133,8 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(2, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertEqual(2, self.countXlaOps(labels)) + self.assertFalse(InLabels(labels, "MatMult")) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 98dc73e189f99b7b811487756659d89dacb97d8a..6ef8a68ca5d35d3d2f78f0cb491e7bb98ff97ac9 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -151,7 +151,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): dtype=data_type).reshape(tensor_in_sizes) x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)], dtype=data_type).reshape(filter_in_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: if data_type == np.float32: tolerance = 1e-4 else: @@ -247,7 +247,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): dtype=np.float32).reshape(tensor_in_sizes) x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)], dtype=np.float32).reshape(filter_in_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32) t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=np.float32) with self.test_scope(): @@ -321,7 +321,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(use_xla): - with self.test_session(): + with self.cached_session(): t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)]) t1 = array_ops.placeholder(np.float32, shape=filter_sizes) t2 = array_ops.placeholder(np.float32, shape=output_sizes) @@ -356,7 +356,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(use_xla): - with self.test_session(): + with self.cached_session(): t0 = array_ops.placeholder(np.float32, shape=input_sizes) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = array_ops.placeholder(np.float32, shape=output_sizes) diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 154e36b10e6da409606ae6022aaf53e34c8e37cc..5f01e128f0b0fa725d99b00ba3406bd50a1b8962 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import test class DynamicUpdateSliceOpsTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index edd78153b56bb5bf1c268936fb82a60581389733..50b04daa6b9f4159a3c4bdeecaf900a5b35a833c 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import googletest class DynamicStitchTest(xla_test.XLATestCase): def _AssertDynamicStitchResultIs(self, indices, data, expected): - with self.test_session() as session: + with self.cached_session() as session: index_placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices ] diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 422f36d43bf38d26f057c18da716d7e281c286af..63cee550fde9d9d4314b1541fba191df776a4da2 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -32,6 +32,7 @@ from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -100,7 +101,7 @@ class EagerTest(xla_test.XLATestCase): self.assertAllEqual(15, product) # Run some ops graphly - with context.graph_mode(), self.test_session() as sess: + with context.graph_mode(), self.cached_session() as sess: with self.test_scope(): three = constant_op.constant(3) five = constant_op.constant(5) @@ -122,6 +123,14 @@ class EagerTest(xla_test.XLATestCase): with self.test_scope(): self.assertAllEqual(2, array_ops.identity(2)) + def testRandomOps(self): + with self.test_scope(): + tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32) + row0 = tensor[0].numpy() + row1 = tensor[1].numpy() + # It should be very unlikely to rng to generate two equal rows. + self.assertFalse((row0 == row1).all()) + def testIdentityOnVariable(self): with self.test_scope(): v = resource_variable_ops.ResourceVariable(True) @@ -342,6 +351,38 @@ class EagerFunctionTest(xla_test.XLATestCase): var = f(v) self.assertEqual(2.0, var.numpy()) + def testReturnResourceHandle(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]]) + + def f(v): + return v.handle + + f = function.defun(f) + handle = f(v) + self.assertAllEqual(v.numpy(), + resource_variable_ops.read_variable_op( + handle, dtypes.float32).numpy()) + + def testReturnMultipleResourceHandles(self): + with self.test_scope(): + v1 = resource_variable_ops.ResourceVariable(1.25) + v2 = resource_variable_ops.ResourceVariable(2.0) + + def f(v): + return v.handle, 3.0 * v, v2.handle, v + v2 + + f = function.defun(f) + v1_handle, v1_times_3, v2_handle, variable_sum = f(v1) + self.assertAllEqual(v1.numpy(), + resource_variable_ops.read_variable_op( + v1_handle, dtypes.float32).numpy()) + self.assertEqual(3.75, v1_times_3.numpy()) + self.assertAllEqual(v2.numpy(), + resource_variable_ops.read_variable_op( + v2_handle, dtypes.float32).numpy()) + self.assertEqual(3.25, variable_sum.numpy()) + def testAllArgumentKinds(self): """Test a complex function that takes different argument kinds. @@ -434,7 +475,6 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertAllEqual((2, 3, 4), dz.shape.as_list()) def testNestedDefun(self): - self.skipTest('Nested defuns do not work on TPU at the moment') with self.test_scope(): @function.defun @@ -449,6 +489,72 @@ class EagerFunctionTest(xla_test.XLATestCase): y = two_x_plus_1(x) self.assertAllEqual([5, 7, 9], y.numpy()) + def testNestedDefunWithVariable(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def g(x): + x = v0 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + y = f(x) + + self.assertEqual(75, y.numpy()) + + def testNestedDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def g(x): + x = v0 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + + def testNestedDefunInGradientTapeDifferentVars(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + v1 = resource_variable_ops.ResourceVariable(3.0) + + @function.defun + def g(x): + x = v1 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape(persistent=True) as tape: + y = f(x) + dy_v0 = tape.gradient(y, v0) + dy_v1 = tape.gradient(y, v1) + + self.assertEqual(45, y.numpy()) + self.assertEqual(9, dy_v0.numpy()) + self.assertEqual(15, dy_v1.numpy()) + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 5529fdbb090315e1d7f47589777d8a538c90db2b..37061e91d161db352b388a965eb72c9c32d3d752 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -44,7 +44,7 @@ class ExtractImagePatches(xla_test.XLATestCase): strides = [1] + strides + [1] rates = [1] + rates + [1] - with self.test_session(): + with self.cached_session(): image_placeholder = array_ops.placeholder(dtypes.float32) with self.test_scope(): out_tensor = array_ops.extract_image_patches( diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py index c48ab178bf53558084fb500b2811c6f0b77a7943..2178c4455609550226c89ceb185837768be1f622 100644 --- a/tensorflow/compiler/tests/fake_quant_ops_test.py +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -107,7 +107,7 @@ class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase): ], dtype=np.float32) - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): input_placeholder = array_ops.placeholder( dtypes.float32, inputs.shape, name="inputs") @@ -198,7 +198,7 @@ class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase): [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], dtype=np.float32) - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): gradient_placeholder = array_ops.placeholder( dtypes.float32, gradients.shape, name="gradients") @@ -306,7 +306,7 @@ class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase): ], dtype=np.float32) - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): input_placeholder = array_ops.placeholder( dtypes.float32, inputs.shape, name="inputs") @@ -406,7 +406,7 @@ class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase): expected_backprops_wrt_min = 1.0 + 2.0 expected_backprops_wrt_max = 10.0 + 11.0 - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): gradient_placeholder = array_ops.placeholder( dtypes.float32, gradients.shape, name="gradients") diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index c64ea249ecb97991952a960a6d16e1bb3be35b17..b3e13fbaa6b33bdaa1be123be558059e96de282e 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -71,7 +71,7 @@ class FFTTest(xla_test.XLATestCase): data = np.reshape(data.astype(np.float32).view(np.complex64), shape) data = to_32bit(complex_to_input(data)) expected = to_32bit(input_to_expected(data)) - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) @@ -93,7 +93,7 @@ class FFTTest(xla_test.XLATestCase): data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2] expected = np.swapaxes(expected, -1, -2) expected *= window.sum() # scipy divides by window sum - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py index 0f64cc87cde77fbbef6c4e570879e992bc34bafa..8c7edfd277c992c35a81dd5f261256a86352254e 100644 --- a/tensorflow/compiler/tests/fifo_queue_test.py +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -31,13 +31,13 @@ from tensorflow.python.platform import test class FIFOQueueTest(xla_test.XLATestCase): def testEnqueue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) enqueue_op = q.enqueue((10.0,)) enqueue_op.run() def testEnqueueWithShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) enqueue_correct_op.run() @@ -46,7 +46,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual(1, q.size().eval()) def testMultipleDequeues(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue([1])) self.evaluate(q.enqueue([2])) @@ -55,7 +55,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) def testQueuesDontShare(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue(1)) q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) @@ -64,13 +64,13 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertAllEqual(self.evaluate(q.dequeue()), 1) def testEnqueueDictWithoutNames(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) with self.assertRaisesRegexp(ValueError, "must have names"): q.enqueue({"a": 12.0}) def testParallelEnqueue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -95,7 +95,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertItemsEqual(elems, results) def testParallelDequeue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -119,7 +119,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertItemsEqual(elems, results) def testDequeue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -133,7 +133,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([elems[i]], vals) def testEnqueueAndBlockingDequeue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) elems = [10.0, 20.0, 30.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -163,7 +163,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([elem], result) def testMultiEnqueueAndDequeue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) elems = [(5, 10.0), (10, 20.0), (15, 30.0)] enqueue_ops = [q.enqueue((x, y)) for x, y in elems] @@ -179,12 +179,12 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([y], y_val) def testQueueSizeEmpty(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) self.assertEqual([0], q.size().eval()) def testQueueSizeAfterEnqueueAndDequeue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) enqueue_op = q.enqueue((10.0,)) dequeued_t = q.dequeue() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 1da97fd51217a0f28d4b3ba2ccfae3f6b094e65b..f1b87a5ffb73bed62a80abaa152d335f64d970c5 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -29,7 +29,6 @@ from tensorflow.python.training import adagrad from tensorflow.python.training import ftrl from tensorflow.python.training import gradient_descent - class FtrlOptimizerTest(xla_test.XLATestCase): def initVariableAndGradient(self, dtype): @@ -112,7 +111,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlwithoutRegularization(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -146,7 +145,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlwithoutRegularization2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -174,7 +173,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlWithL1(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -196,13 +195,17 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-7.66718769, -10.91273689]), var0.eval(), rtol=1e-4) + np.array([-7.66718769, -10.91273689]), + var0.eval(), + rtol=1e-4, + bfloat16_rtol=1e-1, + bfloat16_atol=1e-1) self.assertAllCloseAccordingToType( np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -236,7 +239,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): weights will tend to have smaller magnitudes with this parameter set. """ for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -259,9 +262,49 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4) + np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4) + np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4) + + def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self): + """Verifies that l2 shrinkage in FTRL does not change lr schedule.""" + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.1, 0.2], dtype=dtype) + + opt0 = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0, + l2_shrinkage_regularization_strength=0.1) + opt1 = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + update0 = opt0.apply_gradients([(grads0, var0)]) + update1 = opt1.apply_gradients([(grads1, var1)]) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval()) + + # Run 10 steps FTRL + for _ in range(10): + update0.run() + update1.run() + + # var0 is experiencing L2 shrinkage so it should be smaller than var1 + # in magnitude. + self.assertTrue((var0.eval()**2 < var1.eval()**2).all()) + accum0 = list(opt0._slots["accum"].values())[0].eval() + accum1 = list(opt1._slots["accum"].values())[0].eval() + # L2 shrinkage should not change how we update grad accumulator. + self.assertAllCloseAccordingToType(accum0, accum1) # When variables are initialized with Zero, FTRL-Proximal has two properties: # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical @@ -273,9 +316,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testEquivAdagradwithoutRegularization(self): steps = 5 for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.equivAdagradTest_FtrlPart(steps, dtype) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4, half_rtol=1e-2) @@ -284,9 +327,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testEquivGradientDescentwithoutRegularization(self): steps = 5 for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.equivGradientDescentTest_FtrlPart(steps, dtype) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.equivGradientDescentTest_GradientDescentPart( steps, dtype) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 04fba444460e714ce96205361ac02ed492206b04..b1891b918c6584abce9da382088ed0037f5319fb 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = Func(aval, bval) - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -105,7 +105,7 @@ class FunctionTest(xla_test.XLATestCase): def testCompileTimeConstantsInDefun(self): """Tests that XLA handles compile-time constants in defuns.""" - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32) def Foo(a, c, d): @@ -140,7 +140,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = aval + bval * 2 - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): a = array_ops.placeholder(dtypes.float32, name="a") b = array_ops.placeholder(dtypes.float32, name="b") diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 132e42ac7a28d0769b0de12ea0cee6eae752b245..374942a0b339b816944ea5529e4f84134b60017b 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -29,6 +29,11 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn from tensorflow.python.platform import test +DATA_FORMATS = ( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), +) + class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): @@ -65,12 +70,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): grad_offset = np.sum(grad_y, axis=(0, 1, 2)) return grad_x, grad_scale, grad_offset - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testInference(self, data_format): channel = 3 x_shape = [2, 2, 6, channel] @@ -83,7 +83,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format_src) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): # To avoid constant folding x_val_converted = test_utils.ConvertBetweenDataFormats( x_val, data_format_src, data_format) @@ -126,7 +126,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format_src) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): # To avoid constant folding x_val_converted = test_utils.ConvertBetweenDataFormats( x_val, data_format_src, data_format) @@ -170,30 +170,15 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(y_val, y_ref_converted, atol=1e-3) self.assertAllClose(var_val, var_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearning(self, data_format): self._testLearning(False, data_format) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearningWithGradientChecker(self, data_format): self._testLearning(True, data_format) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientTraining(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. @@ -210,7 +195,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val, data_format_src, data_format) x_val_converted = test_utils.ConvertBetweenDataFormats( @@ -241,12 +226,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientInference(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. @@ -260,7 +240,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): var_val = np.random.random_sample(scale_shape).astype(np.float32) data_format_src = "NHWC" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val, data_format_src, data_format) x_val_converted = test_utils.ConvertBetweenDataFormats( diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 23b0aed34fb460f50c241e5a920cb4f6f613b947..7161f4ab339b6f4069dd2b02ddbc6a89973e0074 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class GatherNdTest(xla_test.XLATestCase): def _runGather(self, params, indices): - with self.test_session(): + with self.cached_session(): paramsp = array_ops.placeholder(params.dtype) indicesp = array_ops.placeholder(indices.dtype) with self.test_scope(): @@ -46,7 +46,7 @@ class GatherNdTest(xla_test.XLATestCase): np.array([[4], [4], [0]], np.int32))) def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): - with self.test_session(): + with self.cached_session(): params = np.ones((3, 3), dtype=np.float32) indices_empty = np.empty((0, 2), dtype=np.int32) diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index e9c8ef7c91a728b7dfc948fd9b315e6c9102f6a3..a38e1edafe883f6d3b64e1d7f94e394cccafa2e9 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -42,7 +42,7 @@ class GatherTest(xla_test.XLATestCase): return data def testScalar1D(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([0, 1, 2, 3, 7, 5]) for dtype in self.all_tf_types: for indices in 4, [4], [1, 2, 2, 4, 5]: @@ -51,11 +51,11 @@ class GatherTest(xla_test.XLATestCase): indices_tf = constant_op.constant(indices) gather_t = array_ops.gather(params, indices_tf) gather_val = session.run(gather_t, feed_dict={params: params_np}) - np_val = params_np[indices] + np_val = constant_op.constant(params_np[indices]) self.assertAllEqual(np_val, gather_val) def testScalar2D(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in self.all_tf_types: @@ -65,11 +65,12 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant(2) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, 2, axis=axis) + expected = constant_op.constant( + np.take(params_np, 2, axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testSimpleTwoD32(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in self.all_tf_types: @@ -80,14 +81,15 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant([0, 1, 0, 2]) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testSimpleTwoD32_Int64Indices(self): if np.int64 not in self.int_types: return - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) # The indices must be in bounds for any axis. @@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase): params: params_np, indices: indices_np }) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testHigherRank(self): @@ -114,16 +117,17 @@ class GatherTest(xla_test.XLATestCase): for axis in 0, 1, 2, 3, -1, -2: params = self._buildParams(np.random.randn(*shape), dtype) indices = np.random.randint(shape[axis], size=indices_shape) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): tf_params = array_ops.placeholder(dtype=dtype) tf_indices = constant_op.constant(indices, dtype=dtypes.int32) gather = array_ops.gather(tf_params, tf_indices, axis=axis) gather_value = sess.run(gather, feed_dict={tf_params: params}) - gather_np = np.take(params, indices, axis=axis) + gather_np = constant_op.constant( + np.take(params, indices, axis=axis), dtype) self.assertAllEqual(gather_np, gather_value) def testIndicesWithDifferentDimensions(self): - with self.test_session(): + with self.cached_session(): for dtype in self.numeric_tf_types: params = array_ops.placeholder(dtype=dtype) indices = array_ops.placeholder(dtype=np.int32) @@ -137,7 +141,7 @@ class GatherTest(xla_test.XLATestCase): [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) def testGatherPrecision(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) indices = np.array([1, 2, 3, 1]) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index bf986ade06b11358552ee92df3169f965ce3f534..68fdb5caf4c2a496b5058cdda40ca650484a6e0e 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -54,7 +54,7 @@ class RGBToHSVTest(xla_test.XLATestCase): inp = GenerateNumpyRandomRGB(shape).astype(nptype) # Convert to HSV and back, as a batch and individually - with self.test_session() as sess: + with self.cached_session() as sess: batch0 = array_ops.placeholder(nptype, shape=shape) with self.test_scope(): batch1 = image_ops.rgb_to_hsv(batch0) @@ -78,7 +78,7 @@ class RGBToHSVTest(xla_test.XLATestCase): data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] for nptype in self.float_types: rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255. - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(nptype) with self.test_scope(): hsv = image_ops.rgb_to_hsv(placeholder) @@ -97,7 +97,7 @@ class RGBToHSVTest(xla_test.XLATestCase): for r, g, b in rgb_flat ]) hsv_np = hsv_np.reshape(4, 4, 4, 3) - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(nptype) with self.test_scope(): hsv_op = image_ops.rgb_to_hsv(placeholder) @@ -108,7 +108,7 @@ class RGBToHSVTest(xla_test.XLATestCase): class AdjustContrastTest(xla_test.XLATestCase): def _testContrast(self, x_np, y_np, contrast_factor): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_np.shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -146,7 +146,7 @@ class AdjustContrastTest(xla_test.XLATestCase): return y_np def _adjustContrastTf(self, x_np, contrast_factor): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(np.float32) with self.test_scope(): y = image_ops.adjust_contrast(x, contrast_factor) @@ -180,7 +180,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -198,7 +198,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -216,7 +216,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -244,7 +244,7 @@ class AdjustHueTest(xla_test.XLATestCase): return y_v.reshape(x_np.shape) def _adjustHueTf(self, x_np, delta_h): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(dtypes.float32) with self.test_scope(): y = gen_image_ops.adjust_hue(x, delta_h) @@ -324,7 +324,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128] y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) y = self._adjust_saturation(x, saturation_factor) y_tf = y.eval({x: x_np}) @@ -339,7 +339,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) y = self._adjust_saturation(x, saturation_factor) y_tf = y.eval({x: x_np}) @@ -378,7 +378,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): "gb_same", "rgb_same", ] - with self.test_session(): + with self.cached_session(): for x_shape in x_shapes: for test_style in test_styles: x_np = np.random.rand(*x_shape) * 255. @@ -410,13 +410,14 @@ class ResizeBilinearTest(xla_test.XLATestCase): image_np, target_shape, expected=None, - large_tolerance=False): + large_tolerance=False, + align_corners=True): if expected is None: self.fail("expected must be specified") - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): image = array_ops.placeholder(image_np.dtype) resized = gen_image_ops.resize_bilinear( - image, target_shape, align_corners=True) + image, target_shape, align_corners=align_corners) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) if large_tolerance: self.assertAllClose( @@ -433,7 +434,7 @@ class ResizeBilinearTest(xla_test.XLATestCase): self.fail("input_shape must be specified") if expected is None: self.fail("expected must be specified") - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): dtype = dtype or np.float32 grads = array_ops.placeholder(np.float32) resized = gen_image_ops.resize_bilinear_grad( @@ -579,14 +580,31 @@ class ResizeBilinearTest(xla_test.XLATestCase): dtype=np.float32)), large_tolerance=True) + def testNonAlignCorners3x2To6x4(self): + input_data = [[64, 32], [32, 64], [50, 100]] + expected_data = [[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0], + [32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0], + [50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]] + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array(input_data, dtype=dtype), [6, 4], + expected=np.array(expected_data, dtype=np.float32), + align_corners=False) + + def testNonAlignCorners6x4To3x2(self): + input_data = [[127, 127, 64, 64], [127, 127, 64, 64], [64, 64, 127, 127], + [64, 64, 127, 127], [50, 50, 100, 100], [50, 50, 100, 100]] + expected_data = [[127, 64], [64, 127], [50, 100]] + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array(input_data, dtype=dtype), [3, 2], + expected=np.array(expected_data, dtype=dtype), + align_corners=False) + class NonMaxSuppressionTest(xla_test.XLATestCase): def testNMS128From1024(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - with compat.forward_compatibility_horizon(2018, 8, 8): num_boxes = 1024 boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") @@ -596,7 +614,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -622,10 +640,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(indices_tf.size, max_output_size) def testNMS3From6Boxes(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - with compat.forward_compatibility_horizon(2018, 8, 8): # Three boxes are selected based on IOU. boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], @@ -639,7 +653,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -671,10 +685,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): # Three boxes are selected based on IOU. # One is filtered out by score threshold. - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - with compat.forward_compatibility_horizon(2018, 8, 8): boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] @@ -686,7 +696,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.4, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -714,6 +724,49 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 2) self.assertAllClose(indices_tf[:num_valid], [3, 0]) + def testNMS3Then1WithScoreMaxThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + # One is filtered out by max_output_size. + + with compat.forward_compatibility_horizon(2018, 8, 8): + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 1 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 1) + self.assertAllClose(indices_tf[:num_valid], [3]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 6e0db54b7a74b284dc7d18bcbb07c178c664c1e5..de68ff0e32cd59e65094c0b7319f8ab213eed4db 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -77,11 +77,11 @@ def InLabels(labels, substr): return any([substr in x for x in labels]) -def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" +def MetadataHasXlaOp(run_metadata): + """Returns true if there are XlaRun kernels in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "XlaRun") class JitLaunchTest(test.TestCase): @@ -90,9 +90,10 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node - # actually ran. However, it is sometimes possible for XlaLaunch ops to be - # constant-folded away, so the check is optional. + # + # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun + # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun + # ops to be constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: placeholders = [] @@ -115,7 +116,7 @@ class JitLaunchTest(test.TestCase): print("Compiled Result {}".format(compiled)) if require_kernel_launch: - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) direct = sess.run(direct_op, feeds) print("Direct Result {}".format(direct)) @@ -149,10 +150,10 @@ class JitLaunchTest(test.TestCase): y = math_ops.add(x, x) return y, y - # Exercises compling a function (say, Foo) which calls another - # function (say, Bar) which is not inlined. When the compiler compiles - # Foo, it needs to symbolic execute Bar correctly regardless whether - # Bar is inlined or not. + # Exercises compiling a function (say, Foo) which calls another function + # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs + # to symbolically execute Bar correctly regardless of whether Bar is inlined + # or not. # TODO(b/36139787): Re-enable this test when noinline works again. # Tests compiled=True and noinline=True. @@ -259,7 +260,7 @@ class JitLaunchTest(test.TestCase): # TODO(phawkins): really we would like to test that there were exactly # two kernel launches. However, we have no reliable way to determine # that. - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) expected = np.square(np.dot(dx, dw) + db) self.assertAllClose(expected, output, rtol=1e-1) @@ -289,7 +290,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out) def testIgnoredArguments(self): @@ -313,7 +314,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(28, out) def testLoops(self): @@ -331,7 +332,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(result, np.float32(95), rtol=1e-1) def testCond(self): @@ -356,7 +357,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(result, np.float32(6), rtol=1e-1) def testNestedFunction(self): @@ -441,14 +442,16 @@ class XlaCompilationTest(test.TestCase): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaCompile")) + self.assertFalse(InLabels(labels, "XlaRun")) - # Compile the backprop. One XlaLaunch. + # Compile the backprop. One XlaCompile/XlaRun pair. labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "XlaLaunch")) + self.assertTrue(InLabels(labels, "XlaCompile")) + self.assertTrue(InLabels(labels, "XlaRun")) class ElementWiseFusionTest(test.TestCase): @@ -482,15 +485,19 @@ class ElementWiseFusionTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = RunMetadataLabels(run_metadata) - count = sum("XlaLaunch(" in x for x in labels) - return output, count + xla_compile_count = sum("XlaCompile(" in x for x in labels) + xla_run_count = sum("XlaRun(" in x for x in labels) + self.assertEqual(xla_compile_count, xla_run_count) + + return output, xla_run_count def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " - "--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_fusion_only=true " + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py index 45a04f0cf56e88946b946bedacb25ce6da3121b4..58622114e4f552fb71db9b040a39b57d7da0037c 100644 --- a/tensorflow/compiler/tests/listdiff_op_test.py +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -33,7 +33,7 @@ class ListDiffTest(xla_test.XLATestCase): def _testListDiff(self, x, y, out, idx): for dtype in [dtypes.int32, dtypes.int64]: for index_dtype in [dtypes.int32, dtypes.int64]: - with self.test_session() as sess: + with self.cached_session() as sess: x_tensor = ops.convert_to_tensor(x, dtype=dtype) y_tensor = ops.convert_to_tensor(y, dtype=dtype) with self.test_scope(): diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 253b45902fba2df64e5234f135b373cd2a0a7e2a..c6ad67993e8bc196a74c9a328df8c9200c92c575 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -58,7 +58,7 @@ class LRNTest(xla_test.XLATestCase): return output def _RunAndVerify(self, dtype): - with self.test_session(): + with self.cached_session(): # random shape shape = np.random.randint(1, 16, size=4) # Make depth at least 2 to make it meaningful @@ -110,7 +110,7 @@ class LRNTest(xla_test.XLATestCase): alpha = 1.0 * np.random.rand() beta = 1.0 * np.random.rand() - with self.test_session(): + with self.cached_session(): in_image = constant_op.constant(in_image_vals, shape=shape) out_image = constant_op.constant(out_image_vals, shape=shape) out_grads = constant_op.constant(out_grads_vals, shape=shape) diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 31093c65713df55390c3130b8654fdcb10fbc133..265c0b6d1412de7be3a5bf5e79129cb330ceb162 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -73,7 +73,7 @@ class LSTMTest(test.TestCase): def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar, pad_scalar): - with self.test_session() as sess: + with self.cached_session() as sess: num_inputs = 1 num_nodes = 1 @@ -156,7 +156,7 @@ class LSTMTest(test.TestCase): def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar, pad_scalar): - with self.test_session() as sess: + with self.cached_session() as sess: num_inputs = 1 num_nodes = 1 seq_length = 3 diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 0d9f99f8a6803ecae5f9233518a1768109161ac0..c61965b97fc142ce452cf28def8c937f692d2f84 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test @@ -26,38 +27,167 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(xla_test.XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase): - def _testMatrixBandPart(self, dtype, shape): - with self.test_session(): - batch_shape = shape[:-2] - mat = np.ones(shape).astype(dtype) - batch_mat = np.tile(mat, batch_shape + [1, 1]) - for lower in -1, 0, 1, shape[-2] - 1: - for upper in -1, 0, 1, shape[-1] - 1: - band_np = mat - if lower >= 0: - band_np = np.triu(band_np, -lower) - if upper >= 0: - band_np = np.tril(band_np, upper) - if batch_shape: - band_np = np.tile(band_np, batch_shape + [1, 1]) - - placeholder = array_ops.placeholder(dtype) - with self.test_scope(): - band = array_ops.matrix_band_part( - placeholder, - constant_op.constant(lower, dtype=dtypes.int32), - constant_op.constant(upper, dtype=dtypes.int32)) - feed_dict = {placeholder: batch_mat} - self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) - - def testMatrixBandPart(self): + @parameterized.parameters( + { + 'batch_shape': [], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 7 + }, + ) + def testMatrixBandPart(self, batch_shape, rows, cols): for dtype in self.float_types: - for batch_shape in [[], [2,], [1, 3, 2]]: - for rows in 1, 2, 7: - for cols in 1, 2, 7: - self._testMatrixBandPart(dtype, batch_shape + [rows, cols]) + with self.cached_session(): + mat = np.ones(batch_shape + [rows, cols]).astype(dtype) + batch_mat = np.tile(mat, batch_shape + [1, 1]) + for lower in -1, 0, 1, rows - 1: + for upper in -1, 0, 1, cols - 1: + band_np = mat + if lower >= 0: + band_np = np.triu(band_np, -lower) + if upper >= 0: + band_np = np.tril(band_np, upper) + if batch_shape: + band_np = np.tile(band_np, batch_shape + [1, 1]) + + placeholder = array_ops.placeholder(dtype) + with self.test_scope(): + band = array_ops.matrix_band_part( + placeholder, constant_op.constant(lower, dtype=dtypes.int32), + constant_op.constant(upper, dtype=dtypes.int32)) + feed_dict = {placeholder: batch_mat} + self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 2bb8a97bdaf5836a05501ab9754433e29ae34675..94cd3eeb3179da9b920ea9f03216d602b042a639 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -54,7 +54,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): clean_a = np.tril(a) if lower else np.triu(a) - with self.test_session() as sess: + with self.cached_session() as sess: placeholder_a = MakePlaceholder(a) placeholder_ca = MakePlaceholder(clean_a) placeholder_b = MakePlaceholder(b) diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index c2592c54cf83d41f0e3bdbc1f4dc9ff276ddb078..f77521a7c49dba39849869ddceb7c0e885147722 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -41,7 +41,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -95,7 +95,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testNesterovMomentum(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype) var0_np = np.array([0.1, 0.2], dtype=dtype) @@ -120,7 +120,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index da08225e9fc0d5a8ec21ee9961c4758fa38628b4..a1c07fce732d3b91a7c0550545a03fdab67644d3 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest class NAryOpsTest(xla_test.XLATestCase): def _testNAry(self, op, args, expected, equality_fn=None): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -126,7 +126,7 @@ class NAryOpsTest(xla_test.XLATestCase): [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32)) def testOneHot(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32)) op = array_ops.one_hot(indices, np.int32(4), @@ -148,7 +148,7 @@ class NAryOpsTest(xla_test.XLATestCase): self.assertAllEqual(output, expected) def testSplitV(self): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): output = session.run( array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]], diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index 2f9122645d3c5ccabc8130ac30a3f09cf4bc2de7..f985c5d2d96e06fc0117f3935d61b19c9e8562b1 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -29,14 +29,14 @@ from tensorflow.python.platform import googletest class NullaryOpsTest(xla_test.XLATestCase): def _testNullary(self, op, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): output = op() result = session.run(output) self.assertAllClose(result, expected, rtol=1e-3) def testNoOp(self): - with self.test_session(): + with self.cached_session(): with self.test_scope(): output = control_flow_ops.no_op() # This should not crash. diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py index d68d32057a367776d5b70d5ac21d5618297c605d..7635f89249b7b71e5353e0b7cb1cea5c1f7bca1d 100644 --- a/tensorflow/compiler/tests/oom_test.py +++ b/tensorflow/compiler/tests/oom_test.py @@ -46,7 +46,7 @@ class OutOfMemoryTest(xla_test.XLATestCase): def test_loop(): size = int(2e8) while True: - with self.test_session(): + with self.cached_session(): # Force the compiled code to not be constant by feeding in a # parameter. p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1]) diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index a75d99189b5b673261c9e48f1c5998ea0c575594..77bb839409f0c323ff6ed2c8d6bd105d3003b398 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import googletest class PlaceholderTest(xla_test.XLATestCase): def test_placeholder_with_default_default(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(4.0) ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 @@ -36,7 +36,7 @@ class PlaceholderTest(xla_test.XLATestCase): self.assertEqual(8.0, sess.run(out)) def test_placeholder_with_default_fed(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(4.0) ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index 17f860db61aeda98326a6820771d67ee948b6dda..b6cdd38345b9a9f6b03e8799587e3f6ffe07b407 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -62,7 +62,7 @@ class Pooling3DTest(xla_test.XLATestCase): # numbers from 1. x = np.arange(1.0, total_size + 1, dtype=np.float32) x = x.reshape(input_sizes) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): inputs = array_ops.placeholder(dtypes.float32) t = pool_func( inputs, @@ -210,7 +210,7 @@ class Pooling3DTest(xla_test.XLATestCase): strides = [1] + strides + [1] total_size = np.prod(input_sizes) x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). with ops.device("CPU"): diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 9fc94752ea660f7fb8b2c792180f01485ad04419..d03bd4fdbb7694bc36291faf9b845ec48e26a386 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -89,7 +89,7 @@ class PoolingTest(xla_test.XLATestCase): # numbers from 1. x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32) x = x.reshape(input_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): inputs = array_ops.placeholder(dtypes.float32) t = inputs @@ -324,7 +324,7 @@ class PoolGradTest(xla_test.XLATestCase): # TODO(b/74222344): Fix nan handling for max pool grad. # x[np.random.choice(total_size)] = np.nan x = x.reshape(input_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). with ops.device(self.CPU_DEVICE): diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py index 5fa7706d7294f2cffb7d24a56851be02d759335a..86536da7fed0e2309beb32fee9c7c605491592ed 100644 --- a/tensorflow/compiler/tests/powersign_test.py +++ b/tensorflow/compiler/tests/powersign_test.py @@ -64,7 +64,7 @@ class PowerSignTest(xla_test.XLATestCase): base=math.e, beta=0.9): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): # Initialize variables for numpy implementation. m0, m1 = 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py index cde87db63dbfd7c8d823c6fd0e41eee8b23735bb..c41b4171e26af4f7ad0237d7407a5b3691299595 100644 --- a/tensorflow/compiler/tests/proximal_adagrad_test.py +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_adagrad class ProximalAdagradOptimizerTest(xla_test.XLATestCase): def testResourceProximalAdagradwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -60,7 +60,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertEqual(2, len(opt_vars)) def testProximalAdagradwithoutRegularization2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -84,7 +84,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) def testProximalAdagradWithL1(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -108,7 +108,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) def testProximalAdagradWithL1_L2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -151,7 +151,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): return var0.eval(), var1.eval() def testEquivAdagradwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.applyOptimizer( proximal_adagrad.ProximalAdagradOptimizer( 3.0, @@ -159,7 +159,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): l1_regularization_strength=0.0, l2_regularization_strength=0.0)) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.applyOptimizer( adagrad.AdagradOptimizer( 3.0, initial_accumulator_value=0.1)) diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py index 11eb76871133eba8fcd24621afb03e16614fb005..3d808e6b8a71ef9fa60b671d07bfd907e9f58efc 100644 --- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_gradient_descent class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): def testResourceProximalGradientDescentwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -53,7 +53,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) def testProximalGradientDescentwithoutRegularization2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -75,7 +75,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) def testProximalGradientDescentWithL1(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -97,7 +97,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) def testProximalGradientDescentWithL1_L2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -137,14 +137,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): return var0.eval(), var1.eval() def testEquivGradientDescentwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.applyOptimizer( proximal_gradient_descent.ProximalGradientDescentOptimizer( 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0)) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.applyOptimizer( gradient_descent.GradientDescentOptimizer(3.0)) diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 1b969ee2b3886fca6ec9951d1621ca5af6a673d8..236b1b881dcaffc1a5b0c6395f0605c1d7ef0269 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -71,7 +71,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): x_np = np.random.uniform( low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) - with self.test_session() as sess: + with self.cached_session() as sess: x_tf = array_ops.placeholder(dtype) with self.test_scope(): q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) @@ -101,8 +101,8 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): @parameterized.parameters(*PARAMS) def testQR(self, rows, cols, dtype): - # TODO(b/111317468): implement full_matrices=False, test other types. - for full_matrices in [True]: + # TODO(b/111317468): Test other types. + for full_matrices in [True, False]: # Only tests the (3, 2) case for small numbers of rows/columns. for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): self._test(dtype, batch_dims + (rows, cols), full_matrices) diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..80c338513bc9ff6b8e56c5ad6b904af9e06a3715 --- /dev/null +++ b/tensorflow/compiler/tests/quantized_ops_test.py @@ -0,0 +1,48 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for quantized operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class QuantizedOpsTest(xla_test.XLATestCase): + + # Verify that quantized types can be clustered by XLA. + def testQuantizedTypeRoundtrip(self): + with self.cached_session() as session: + for dtype in self.quantized_tf_types: + in_values = np.array([1, 2, 3, 4, 5, 6]) + expected = [[1, 2], [3, 4], [5, 6]] + with self.test_scope(): + p = array_ops.placeholder(dtype=dtypes.int32) + x = math_ops.cast(p, dtype) + x = array_ops.reshape(x, [3, 2]) + + value = session.run(x, {p: in_values}) + self.assertAllEqual(value, expected) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 2f60e00c37d214d025b161310d57f9cd84884304..36ef6ed5fee78bad10bb1ee0bf3eb7824d05c206 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -35,11 +35,12 @@ class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): - return set(self.numeric_types) - set(self.complex_types) + return set(self.numeric_types) - set( + self.complex_types) - {np.uint8, np.int8} def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = rng(dtype) @@ -57,7 +58,8 @@ class RandomOpsTest(xla_test.XLATestCase): def testRandomUniformIsNotConstant(self): def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=10000) + dtype = dtypes.as_dtype(dtype) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=dtype.max) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) @@ -67,13 +69,17 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testRandomUniformIsInRange(self): for dtype in self._random_types(): - with self.test_session() as sess: + # TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is + # fixed. + if (self.device in ["XLA_GPU", "XLA_CPU" + ]) and (dtype in [dtypes.bfloat16, dtypes.half]): + continue + with self.cached_session() as sess: with self.test_scope(): x = random_ops.random_uniform( shape=[1000], dtype=dtype, minval=-2, maxval=33) @@ -86,16 +92,16 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.truncated_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. - self._testRngIsNotConstant(rng, dtypes.float32) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testTruncatedNormalIsInRange(self): count = 10000000 - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: - with self.test_session() as sess: + # TODO(b/34339814): make this test work with 16 bit float types. + for dtype in self._random_types() & {dtypes.float32, dtypes.float64}: + with self.cached_session() as sess: with self.test_scope(): - x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42) + x = random_ops.truncated_normal(shape=[count], dtype=dtype) y = sess.run(x) def normal_cdf(x): @@ -124,21 +130,21 @@ class RandomOpsTest(xla_test.XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + self.assertAllClose(actual_mean, expected_mean, atol=2e-3) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma actual_median = np.median(y) - self.assertAllClose(actual_median, expected_median, atol=8e-4) + self.assertAllClose(actual_median, expected_median, atol=1e-2) expected_variance = sigma**2 * (1 + ( (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) actual_variance = np.var(y) - self.assertAllClose(actual_variance, expected_variance, rtol=3e-4) + self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3) def testShuffle1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) @@ -149,7 +155,7 @@ class RandomOpsTest(xla_test.XLATestCase): self.assertAllEqual(set(result), set(expected)) def testShuffle2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = array_ops.diag(math_ops.range(20)) shuffle = random_ops.random_shuffle(x) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index c0ea242044540b1cef44186880ba3cd92b8849d6..bddda6f30245d4b8281a77783ec9922d61bd3883 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/core/common_runtime/device.h" @@ -61,7 +63,6 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" @@ -81,7 +82,7 @@ string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; string LocalDeviceToFullDeviceName(const string& device) { - return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); + return absl::StrCat("/job:localhost/replica:0/task:0/device:", device); } constexpr std::array kAllXlaTypes = { @@ -107,11 +108,12 @@ class OpTestBuilder { // Sets an attribute. template - OpTestBuilder& Attr(StringPiece attr_name, T&& value); + OpTestBuilder& Attr(absl::string_view attr_name, T&& value); // Overload needed to allow {...} expressions for value. template - OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list value); + OpTestBuilder& Attr(absl::string_view attr_name, + std::initializer_list value); // Adds nodes that executes the operator under test on 'device' to 'graphdef'. // If 'use_jit' is true, marks the operator under test to be compiled by XLA. @@ -185,13 +187,13 @@ OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type, } template -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) { AddNodeAttr(attr_name, std::forward(value), &node_def_); return *this; } template -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, std::initializer_list value) { Attr>(attr_name, std::move(value)); return *this; @@ -209,7 +211,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, NodeDef* test_def = graphdef->add_node(); *test_def = node_def_; - test_def->set_name(strings::StrCat(name_prefix, "_op_under_test")); + test_def->set_name(absl::StrCat(name_prefix, "_op_under_test")); test_def->set_device(device); AddDefaultsToNodeDef(*op_def, test_def); if (use_jit) { @@ -224,7 +226,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, // Build feed and fetch nodes. for (int i = 0; i < input_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_input_", i); + string name = absl::StrCat(name_prefix, "_input_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder") .Device(device) .Attr("dtype", input_types[i]) @@ -235,7 +237,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, for (int i = 0; i < output_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_output_", i); + string name = absl::StrCat(name_prefix, "_output_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity") .Device(device) .Attr("T", output_types[i]) @@ -275,13 +277,13 @@ class OpTest : public ::testing::Test { // Select a random element from 'candidates'. template - T Choose(gtl::ArraySlice candidates); + T Choose(absl::Span candidates); static constexpr int kDefaultMaxRank = 5; static constexpr int64 kDefaultMaxDimensionSize = 256LL; // Returns true if 'dims' have a size less than tf_xla_max_tensor_size. - bool TensorSizeIsOk(gtl::ArraySlice dims); + bool TensorSizeIsOk(absl::Span dims); // Returns a random dimension size, in the range [min, max). int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize); @@ -307,11 +309,11 @@ class OpTest : public ::testing::Test { // of the type's range. If the shape is omitted, a random shape is used. // TODO(phawkins): generalize this code to a caller-supplied distribution. Tensor RandomTensor(DataType dtype, bool needs_unique_values, - gtl::ArraySlice shape); + absl::Span shape); Tensor RandomTensor(DataType dtype); // Like RandomTensor, but uses values >= 0. - Tensor RandomNonNegativeTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomNonNegativeTensor(DataType dtype, absl::Span shape); Tensor RandomNonNegativeTensor(DataType dtype); // Returns a random subset of the integers in the range [0, rank), suitable @@ -415,7 +417,7 @@ void OpTest::Repeatedly(const std::function& fn) { } template -T OpTest::Choose(gtl::ArraySlice candidates) { +T OpTest::Choose(absl::Span candidates) { std::uniform_int_distribution d(0, candidates.size() - 1); return candidates[d(generator())]; } @@ -425,7 +427,7 @@ int64 OpTest::RandomDim(int64 min, int64 max) { return size_distribution(generator()); } -bool OpTest::TensorSizeIsOk(gtl::ArraySlice dims) { +bool OpTest::TensorSizeIsOk(absl::Span dims) { int64 size = 1LL; for (int64 dim : dims) { size *= dim; @@ -451,7 +453,7 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, } Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, - gtl::ArraySlice shape) { + absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { @@ -548,7 +550,7 @@ Tensor OpTest::RandomTensor(DataType dtype) { } Tensor OpTest::RandomNonNegativeTensor(DataType dtype, - gtl::ArraySlice shape) { + absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { @@ -726,11 +728,11 @@ bool IsClose(const complex64& x, const complex64& y, double atol, template string Str(T x) { - return strings::StrCat(x); + return absl::StrCat(x); } template <> string Str(complex64 x) { - return strings::StrCat("(", x.real(), ", ", x.imag(), ")"); + return absl::StrCat("(", x.real(), ", ", x.imag(), ")"); } template @@ -740,11 +742,11 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { if (!IsClose(Tx(i), Ty(i), atol, rtol)) { - return errors::InvalidArgument(strings::StrCat( - i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ", - Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(), - "atol = ", atol, " rtol = ", rtol, - " tol = ", atol + rtol * Abs(Tx(i)))); + return errors::InvalidArgument( + absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)), + " vs. ", Str(Ty(i)), ". x = ", x.DebugString(), + "y = ", y.DebugString(), "atol = ", atol, + " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i)))); } } return Status::OK(); @@ -756,7 +758,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { if (Tx(i) != Ty(i)) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i), ". x = ", x.DebugString(), "y = ", y.DebugString())); } @@ -771,14 +773,14 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, double rtol) { if (a.dtype() != b.dtype()) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( "Tensors have different types: ", DataTypeString(a.dtype()), " and ", DataTypeString(b.dtype()))); } if (!a.IsSameSize(b)) { - return errors::InvalidArgument(strings::StrCat( - "Tensors have different shapes: ", a.shape().DebugString(), " and ", - b.shape().DebugString())); + return errors::InvalidArgument( + absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(), + " and ", b.shape().DebugString())); } switch (a.dtype()) { @@ -827,7 +829,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } string cpu_device = - LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0")); + LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0")); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; @@ -842,7 +844,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; Status status = builder.BuildGraph( - strings::StrCat("test", num_tests_, "_expected"), cpu_device, + absl::StrCat("test", num_tests_, "_expected"), cpu_device, /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, &expected_inputs, &expected_fetches); if (!status.ok()) { @@ -851,7 +853,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } NodeDef* node_def; - status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), + status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"), test_device, tf_xla_test_use_jit, &graph, &node_def, &test_inputs, &test_fetches); if (!status.ok()) { @@ -1884,7 +1886,8 @@ TEST_F(OpTest, DynamicStitch) { for (int i = 0; i < n; ++i) { TensorShape shape(index_dims[i]); Tensor t = test::AsTensor( - gtl::ArraySlice(indices, pos, shape.num_elements()), shape); + absl::Span(indices).subspan(pos, shape.num_elements()), + shape); builder.Input(t); pos += t.NumElements(); } diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index cea2ec816f85e88b11e6e80c91c14fca9015f45c..132c59c32c9db0c8759bdbb31f8613c3ef88b485 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import functools import itertools +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test @@ -30,22 +31,24 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ReduceOpsTest(xla_test.XLATestCase): - +@parameterized.named_parameters(('32_bit_index', dtypes.int32), + ('64_bit_index', dtypes.int64)) +class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs, + index_dtype, rtol=1e-4, atol=1e-4): """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" for test_input in test_inputs: - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): a = array_ops.placeholder(dtype) - index = array_ops.placeholder(dtypes.int32) + index = array_ops.placeholder(index_dtype) out = tf_reduce_fn(a, index) result = sess.run(out, {a: test_input, index: [0]}) self.assertAllClose( @@ -89,22 +92,23 @@ class ReduceOpsTest(xla_test.XLATestCase): np.array([[False, True, False], [True, True, False]]), ] - def testReduceSumF32(self): - self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA) + def testReduceSumF32(self, index_dtype): + self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA, + index_dtype) - def testReduceSumC64(self): + def testReduceSumC64(self, index_dtype): self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, - self.COMPLEX_DATA) + self.COMPLEX_DATA, index_dtype) - def testReduceProdF32(self): + def testReduceProdF32(self, index_dtype): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, - self.REAL_DATA) + self.REAL_DATA, index_dtype) - def testReduceProdC64(self): + def testReduceProdC64(self, index_dtype): self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, - self.COMPLEX_DATA) + self.COMPLEX_DATA, index_dtype) - def testReduceMin(self): + def testReduceMin(self, index_dtype): def reference_min(dtype, inp, axis): """Wrapper around np.amin that returns +infinity for an empty input.""" @@ -119,9 +123,9 @@ class ReduceOpsTest(xla_test.XLATestCase): [np.float32, np.int32, np.int64]): self._testReduction(math_ops.reduce_min, functools.partial(reference_min, dtype), dtype, - self.REAL_DATA) + self.REAL_DATA, index_dtype) - def testReduceMax(self): + def testReduceMax(self, index_dtype): def reference_max(dtype, inp, axis): """Wrapper around np.amax that returns -infinity for an empty input.""" @@ -137,23 +141,25 @@ class ReduceOpsTest(xla_test.XLATestCase): [np.float32, np.int32, np.int64]): self._testReduction(math_ops.reduce_max, functools.partial(reference_max, dtype), dtype, - self.REAL_DATA) + self.REAL_DATA, index_dtype) - def testReduceMeanF32(self): + def testReduceMeanF32(self, index_dtype): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. self._testReduction(math_ops.reduce_mean, np.mean, np.float32, - self.NONEMPTY_REAL_DATA) + self.NONEMPTY_REAL_DATA, index_dtype) - def testReduceMeanC64(self): + def testReduceMeanC64(self, index_dtype): self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, - self.NONEMPTY_COMPLEX_DATA) + self.NONEMPTY_COMPLEX_DATA, index_dtype) - def testReduceAll(self): - self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA) + def testReduceAll(self, index_dtype): + self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA, + index_dtype) - def testReduceAny(self): - self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) + def testReduceAny(self, index_dtype): + self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA, + index_dtype) class ReduceOpPrecisionTest(xla_test.XLATestCase): @@ -178,7 +184,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase): """ for test_input in test_inputs: - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): a = array_ops.placeholder(dtype) index = array_ops.placeholder(dtypes.int32) @@ -213,7 +219,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase): bf16_max = np.float32(dtypes.bfloat16.max) f32_max = dtypes.float32.max - value = min(bf16_max, f32_max - bf16_max) + value = min(bf16_max, f32_max - bf16_max) / 2 self._testReduceSum( dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype, itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3)) diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py index c69b6837b0f88ced844faf3713a29a1c14c8790d..ff20ea3f4287b4666684501fa4920435a77b4183 100644 --- a/tensorflow/compiler/tests/reduce_window_test.py +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -32,7 +32,7 @@ class ReduceWindowTest(xla_test.XLATestCase): """Test cases for xla.reduce_window.""" def _reduce_window(self, operand, init, reducer, **kwargs): - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(operand.dtype) with self.test_scope(): output = xla.reduce_window(placeholder, init, reducer, **kwargs) diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..96e0b074754032dd64c479b5e587b664ff066e2b --- /dev/null +++ b/tensorflow/compiler/tests/reshape_op_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slicing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase): + + @parameterized.named_parameters(('32_bit_index', dtypes.int32), + ('64_bit_index', dtypes.int64)) + def testBasic(self, index_dtype): + for dtype in self.numeric_types: + with self.cached_session(): + i = array_ops.placeholder(dtype, shape=[2, 3]) + with self.test_scope(): + shape = constant_op.constant([3, 2], dtype=index_dtype) + o = array_ops.reshape(i, shape) + params = { + i: [[1, 2, 3], [4, 5, 6]], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[1, 2], [3, 4], [5, 6]], result) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index d01c676e7c2fe705344f26818350c46c30451c67..392290fd92d0c7c928581422433892147374b2dd 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -32,33 +32,40 @@ class ReverseOpsTest(xla_test.XLATestCase): def testReverseOneDim(self): shape = (7, 5, 9, 11) - for revdim in range(len(shape)): + for revdim in range(-len(shape), len(shape)): self._AssertReverseEqual([revdim], shape) def testReverseMoreThanOneDim(self): shape = (7, 5, 9, 11) + # The offset is used to test various (but not all) combinations of negative + # and positive axis indices that are guaranteed to not collide at the same + # index. for revdims in itertools.chain.from_iterable( - itertools.combinations(range(len(shape)), k) - for k in range(2, len(shape)+1)): + itertools.combinations(range(-offset, + len(shape) - offset), k) + for k in range(2, + len(shape) + 1) + for offset in range(0, len(shape))): self._AssertReverseEqual(revdims, shape) def _AssertReverseEqual(self, revdims, shape): np.random.seed(120) pval = np.random.randint(0, 100, size=shape).astype(float) - with self.test_session(): + with self.cached_session(): with self.test_scope(): p = array_ops.placeholder(dtypes.int32, shape=shape) axis = constant_op.constant( np.array(revdims, dtype=np.int32), - shape=(len(revdims),), dtype=dtypes.int32) + shape=(len(revdims),), + dtype=dtypes.int32) rval = array_ops.reverse(p, axis).eval({p: pval}) slices = [ - slice(-1, None, -1) if d in revdims else slice(None) - for d in range(len(shape))] - self.assertEqual( - pval[slices].flatten().tolist(), - rval.flatten().tolist()) + slice(-1, None, -1) + if d in revdims or d - len(shape) in revdims else slice(None) + for d in range(len(shape)) + ] + self.assertEqual(pval[slices].flatten().tolist(), rval.flatten().tolist()) if __name__ == '__main__': diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index ccfa63001653537c4d1b7140e3d745c126f9034b..abc822ef363e5d83c99bb963582662ccfce4cd6d 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -35,7 +35,7 @@ class ReverseSequenceTest(xla_test.XLATestCase): seq_lengths, truth, expected_err_re=None): - with self.test_session(): + with self.cached_session(): p = array_ops.placeholder(dtypes.as_dtype(x.dtype)) lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype)) with self.test_scope(): @@ -85,7 +85,7 @@ class ReverseSequenceTest(xla_test.XLATestCase): def testSeqLength(self): for dtype in self.all_types: - for seq_dtype in self.int_types: + for seq_dtype in self.all_types & {np.int32, np.int64}: self._testBasic(dtype, seq_dtype) diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index ff8bbac911abe73f946464663984ff1626302882..8840a1329a907bddc6ef1cb6dd1c2a6d234def5c 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -55,7 +55,7 @@ class RmspropTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: for centered in [False, True]: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): # Initialize variables for numpy implementation. var0_np = np.array([1.0, 2.0], dtype=dtype) grads0_np = np.array([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 4292352e76ebcef7dbf41df7b857d2604a468117..897db384b7e8067b0460b5f344201f101a4d8479 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -78,7 +78,7 @@ class CumsumTest(xla_test.XLATestCase): def _compare(self, x, axis, exclusive, reverse): np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( feed_dict={p: x}) @@ -100,7 +100,7 @@ class CumsumTest(xla_test.XLATestCase): for dtype in self.valid_dtypes: x = np.arange(1, 6).reshape([5]).astype(dtype) for axis_dtype in self.axis_dtypes(): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) axis = constant_op.constant(0, axis_dtype) math_ops.cumsum(p, axis).eval(feed_dict={p: x}) @@ -131,7 +131,7 @@ class CumsumTest(xla_test.XLATestCase): def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): input_tensor = ops.convert_to_tensor(x) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, @@ -156,7 +156,7 @@ class CumprodTest(xla_test.XLATestCase): def _compare(self, x, axis, exclusive, reverse): np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) prod = math_ops.cumprod(p, axis, exclusive, reverse) tf_out = prod.eval(feed_dict={p: x}) @@ -178,7 +178,7 @@ class CumprodTest(xla_test.XLATestCase): for dtype in self.valid_dtypes: x = np.arange(1, 6).reshape([5]).astype(dtype) for axis_dtype in self.axis_dtypes(): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) axis = constant_op.constant(0, axis_dtype) math_ops.cumprod(x, axis).eval(feed_dict={p: x}) @@ -209,7 +209,7 @@ class CumprodTest(xla_test.XLATestCase): def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): input_tensor = ops.convert_to_tensor(x) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index f606f88545d0b6f0b52cee9b93083a6bd91169bc..693f8513bc54e30060a2e963abd504768535a50a 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -119,7 +119,7 @@ class ScatterNdTest(xla_test.XLATestCase): self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) def _runScatterNd(self, indices, updates, shape): - with self.test_session(): + with self.cached_session(): updates_placeholder = array_ops.placeholder(updates.dtype) indices_placeholder = array_ops.placeholder(indices.dtype) with self.test_scope(): diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 772c20fd424577c3e06eeae409f424b77b52aa8a..287bb0d84e24de3bdcde3aa4c61acee00626e88f 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -32,7 +32,7 @@ class SegmentReductionOpsTest(xla_test.XLATestCase): """Test cases for segment reduction ops.""" def _segmentReduction(self, op, data, indices, num_segments): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): d = array_ops.placeholder(data.dtype, shape=data.shape) if isinstance(indices, int): i = array_ops.placeholder(np.int32, shape=[]) diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 6c4890565d2083a9493abc59bd563c4dd9fdb186..2c611a959e1d71c53e44bc92c31258153d01507d 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -29,7 +29,7 @@ class SliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.slice(i, [2], [4]) @@ -40,9 +40,22 @@ class SliceTest(xla_test.XLATestCase): self.assertAllEqual([2, 3, 4, 5], result) + def testZeroSlice(self): + for dtype in self.numeric_types: + with self.cached_session(): + i = array_ops.placeholder(dtype, shape=[2]) + with self.test_scope(): + o = array_ops.slice(i, [0], [0]) + params = { + i: [0, 1], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([], result) + def test3D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) with self.test_scope(): o = array_ops.slice(i, [1, 2, 2], [1, 1, 4]) @@ -64,7 +77,7 @@ class SliceTest(xla_test.XLATestCase): def test3DWithDynamicBegin(self): """Tests a slice where the start offset is not known at compile time.""" for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) begin = array_ops.placeholder(dtypes.int32, shape=[3]) with self.test_scope(): @@ -88,7 +101,7 @@ class SliceTest(xla_test.XLATestCase): def test3DWithDynamicBeginAndNegativeSize(self): """Tests a slice where `begin` is fed dynamically and `size` contains -1.""" for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) begin = array_ops.placeholder(dtypes.int32, shape=[3]) with self.test_scope(): @@ -114,7 +127,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.strided_slice(i, [2], [6], [2]) @@ -127,7 +140,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test1DNegativeStride(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.strided_slice(i, [6], [2], [-2]) @@ -140,7 +153,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test2DDegenerate(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): o = array_ops.strided_slice(i, [-1, 0], [0, 3]) @@ -154,7 +167,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test2DDegenerateNegativeStride(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1]) @@ -168,7 +181,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test3D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) with self.test_scope(): o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2]) @@ -189,7 +202,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test3DNegativeStride(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 4, 10]) with self.test_scope(): o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2]) diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 7ff01be3cb4848d6bb85b8ab96b3ee1db6889791..dbf4beb693ec1766e6b7b5daaed4be4e1d874fba 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import test class XlaSortOpTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -48,10 +48,6 @@ class XlaSortOpTest(xla_test.XLATestCase): self.assertAllClose(v, result, rtol=1e-3) def testSort(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) for dtype in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=dtype) @@ -60,10 +56,6 @@ class XlaSortOpTest(xla_test.XLATestCase): xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) def testTopK(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): @@ -89,10 +81,6 @@ class XlaSortOpTest(xla_test.XLATestCase): expected=[x[indices].astype(dtype), indices]) def testTopK2D(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): @@ -122,16 +110,12 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: return - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=4) @@ -144,16 +128,12 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: return - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=6) diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index c685bc548f9f6f8f7723c6f94dfd45f5420b4a67..33b84cec7188c85a3bacb20a6df29c73adbd107c 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -72,7 +72,7 @@ class SpaceToBatchTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" def _testPad(self, inputs, paddings, block_size, outputs): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self.float_types: # outputs = space_to_batch(inputs) placeholder = array_ops.placeholder(dtype) @@ -155,7 +155,7 @@ class SpaceToBatchNDTest(xla_test.XLATestCase): def _testPad(self, inputs, block_shape, paddings, outputs): block_shape = np.array(block_shape) paddings = np.array(paddings).reshape((len(block_shape), 2)) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self.float_types: # TODO(b/68813416): Skip bfloat16's as the input type for direct is # float32 and results in a mismatch, while making testDirect provide the diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py index 3db8101c4bfbb1b53c7318a36519612984d6f179..07afd1ab3fb78d5accc52ee2382af0b9fb8079d3 100644 --- a/tensorflow/compiler/tests/sparse_to_dense_op_test.py +++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py @@ -45,32 +45,32 @@ def _SparseToDense(sparse_indices, class SparseToDenseTest(xla_test.XLATestCase): def testInt(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1, 0) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def testFloat(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) self.assertAllClose(np_ans, tf_ans) def testSetValue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1) np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def testSetSingleValue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1, -1) np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def test2d(self): # pylint: disable=bad-whitespace - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1) np_ans = np.array([[-1, -1, -1, -1], [-1, -1, -1, 1], @@ -78,12 +78,12 @@ class SparseToDenseTest(xla_test.XLATestCase): self.assertAllClose(np_ans, tf_ans) def testZeroDefault(self): - with self.test_session(): + with self.cached_session(): x = sparse_ops.sparse_to_dense(2, [4], 7).eval() self.assertAllEqual(x, [0, 0, 7, 0]) def test3d(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1) np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 np_ans[1, 3, 0] = 1 @@ -91,25 +91,25 @@ class SparseToDenseTest(xla_test.XLATestCase): self.assertAllClose(np_ans, tf_ans) def testBadShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): _SparseToDense([1, 3], [[5], [3]], 1, -1) def testBadValue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesOpError( r"sparse_values has incorrect shape \[2,1\], " r"should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [[5], [3]], -1) def testBadNumValues(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesOpError( r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [1, 2, 3], -1) def testBadDefault(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesOpError("default_value should be a scalar"): _SparseToDense([1, 3], [5], [1, 2], [0]) diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index b7dd787feff2b22a9cfb5d43a4ba6ceb6eb0b301..720595a159eea997be2246c4c7dad49612b257eb 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import test class StackOpTest(xla_test.XLATestCase): def testStackPushPop(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): size = array_ops.placeholder(dtypes.int32) v = array_ops.placeholder(dtypes.float32) h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") @@ -41,7 +41,7 @@ class StackOpTest(xla_test.XLATestCase): self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]})) def testStackPushPopSwap(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): a = np.arange(2000) x = array_ops.placeholder(dtypes.float32) h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") @@ -51,7 +51,7 @@ class StackOpTest(xla_test.XLATestCase): self.assertAllClose(a, c1.eval({x: a})) def testMultiStack(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): v = array_ops.placeholder(dtypes.float32) h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") c1 = gen_data_flow_ops.stack_push_v2(h1, v) @@ -66,7 +66,7 @@ class StackOpTest(xla_test.XLATestCase): def testSameNameStacks(self): """Different stacks with the same name do not interfere.""" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v1 = array_ops.placeholder(dtypes.float32) v2 = array_ops.placeholder(dtypes.float32) h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") @@ -84,14 +84,14 @@ class StackOpTest(xla_test.XLATestCase): self.assertAllClose(out2, 5.0) def testCloseStack(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): size = array_ops.placeholder(dtypes.int32) h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1, {size: 5}) def testPushCloseStack(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v = array_ops.placeholder(dtypes.float32) h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") c = gen_data_flow_ops.stack_push_v2(h, v) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index d162675ef840131485128414b4a29e3cd89c8761..f3861043b27ebb131554ff49af8c986229fc15ac 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -34,11 +34,11 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): - return [dtypes.float32] + return self.float_types & {dtypes.float32, dtypes.float64} def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seeds = [(x, y) for x in range(5) for y in range(5)] * 3 for stateless_op in [ @@ -55,7 +55,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertEqual(s0 == s1, np.all(v0 == v1)) def testRandomUniformIsInRange(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_uniform( @@ -74,7 +74,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testDistributionOfStatelessRandomUniform(self): """Use Pearson's Chi-squared test to test for uniformity.""" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 @@ -88,7 +88,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertTrue(self._chi_squared(y, 10) < 16.92) def testRandomNormalIsFinite(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_uniform( @@ -111,7 +111,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testDistributionOfStatelessRandomNormal(self): """Use Anderson-Darling test to test distribution appears normal.""" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 @@ -124,9 +124,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertTrue(self._anderson_darling(y) < 2.492) def testTruncatedNormalIsInRange(self): - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: - with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + with self.cached_session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 10000000 x = stateless.stateless_truncated_normal( @@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + self.assertAllClose(actual_mean, expected_mean, atol=5e-4) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index f332aa2e9b97e13654cf9b10588c18fed32f7ad4..78244d0b366d9128a4c59f786e4c5ac12e743b75 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -44,7 +44,7 @@ def _make_converter(dtype): class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayWriteRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -66,7 +66,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([], flow_val.shape) def _testTensorArrayWritePack(self, tf_dtype): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -86,7 +86,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayWritePack(dtype) def testEmptyTensorArrayPack(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) @@ -100,7 +100,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([3, 0, 1], c0.eval().shape) def _testTensorArrayWriteConcat(self, tf_dtype): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -121,7 +121,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayWriteConcat(dtype) def _testTensorArrayUnpackRead(self, tf_dtype): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -176,7 +176,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayUnpackReadMaybeLegacy() def _testTensorArraySplitRead(self, tf_dtype): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -228,7 +228,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArraySplitRead(dtype) def testTensorGradArrayWriteRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -261,7 +261,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[-2.0]], g_d2) def testTensorGradArrayDynamicWriteRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -300,7 +300,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(3, g_vs) def testTensorGradAccessTwiceReceiveSameObject(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3, element_shape=[1, 2]) @@ -317,7 +317,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[4.0, 5.0]], d_r1_0) def testTensorArrayWriteWrongIndexOrDataTypeFails(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) @@ -331,7 +331,7 @@ class TensorArrayTest(xla_test.XLATestCase): # the first type, but try to read the other type. if len(self.float_types) > 1: dtype1, dtype2 = list(self.float_types)[:2] - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtype1, tensor_array_name="foo", size=3) @@ -347,7 +347,7 @@ class TensorArrayTest(xla_test.XLATestCase): w0.read(1) def testTensorArraySplitIncompatibleShapesFails(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -379,7 +379,7 @@ class TensorArrayTest(xla_test.XLATestCase): ta.split([1.0], [1]).flow.eval() def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False) @@ -410,7 +410,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayWriteGradientAddMultipleAdds(dtype) def testMultiTensorArray(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): h1 = tensor_array_ops.TensorArray( size=1, dtype=dtypes.float32, tensor_array_name="foo") w1 = h1.write(0, 4.0) @@ -425,7 +425,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllClose(9.0, r.eval()) def _testTensorArrayGradientWriteReadType(self, dtype): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.as_dtype(dtype), tensor_array_name="foo", @@ -478,7 +478,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientWriteReadType(dtype) def _testTensorArrayGradientWritePackConcatAndRead(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -513,7 +513,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientWritePackConcatAndRead() def testTensorArrayReadTwice(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) ta_readtwice = tensor_array_ops.TensorArray( @@ -529,7 +529,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) def _testTensorArrayGradientUnpackRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -557,7 +557,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientUnpackRead() def testTensorArrayGradientSplitConcat(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=2) @@ -581,21 +581,21 @@ class TensorArrayTest(xla_test.XLATestCase): grad_vals[0]) def testCloseTensorArray(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) c1 = ta.close() session.run(c1) def testSizeTensorArray(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) s = ta.size() self.assertAllEqual(3, s.eval()) def testWriteCloseTensorArray(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -608,7 +608,7 @@ class TensorArrayTest(xla_test.XLATestCase): # TODO(phawkins): implement while loops. # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): # np_dtype = dtype.as_numpy_dtype - # with self.test_session() as session, self.test_scope(): + # with self.cached_session() as session, self.test_scope(): # v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)) # var = variables.Variable(np.arange(100, 105, dtype=np_dtype)) # state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype)) @@ -692,7 +692,7 @@ class TensorArrayTest(xla_test.XLATestCase): # dynamic_size=True, dtype=dtypes.float32) # def testGradSerialTwoLoops(self): - # with self.test_session(), self.test_scope(): + # with self.cached_session(), self.test_scope(): # num_steps = 100 # acc = tensor_array_ops.TensorArray( # dtype=dtypes.float32, @@ -725,7 +725,7 @@ class TensorArrayTest(xla_test.XLATestCase): # self.assertAllClose(31.0, grad.eval()) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): a = array_ops.identity( np.arange( 3 * 5, dtype=np.float32).reshape(3, 5) + 1) @@ -757,7 +757,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(joint_grad_b_t, g0) def testWriteShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) c0 = constant_op.constant([4.0, 5.0]) @@ -781,7 +781,7 @@ class TensorArrayTest(xla_test.XLATestCase): w0.write(0, c2) def testPartlyUnknownShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=6) @@ -821,7 +821,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list()) def _testUnpackShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -846,7 +846,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testUnpackShape() def testSplitShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -867,7 +867,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) def testWriteUnknownShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -879,7 +879,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) def _testGradientWhenNotAllComponentsRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) x = constant_op.constant([2.0, 3.0]) w = ta.unstack(x) @@ -893,7 +893,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testGradientWhenNotAllComponentsRead() def _testTensorArrayEvalEmpty(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=0, infer_shape=False) with self.assertRaisesOpError( @@ -906,7 +906,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayEvalEmpty() def _testTensorArrayEvalEmptyWithDefault(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=0, infer_shape=True) self.assertEqual(0, ta.size().eval()) @@ -921,7 +921,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayEvalEmptyWithDefault() def testTensorArrayScatterReadAndGradients(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -946,7 +946,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) def testTensorArrayWriteGatherAndGradients(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -974,7 +974,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(expected_grad, grad_vals[0]) def testTensorArrayIdentity(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2, infer_shape=False) ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4, diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index effa5a59fee7dda543b2c409dfaa27a972a55808..98a07709c611178effd7794ba58ba89770c6d77f 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import googletest class TernaryOpsTest(xla_test.XLATestCase): def _testTernary(self, op, a, b, c, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") @@ -122,8 +122,7 @@ class TernaryOpsTest(xla_test.XLATestCase): expected=np.array([[2], [5]], dtype=dtype)) def testClipByValue(self): - # TODO(b/78258593): enable integer types here too. - for dtype in self.float_types: + for dtype in self.numeric_types - self.complex_types: test_cases = [ (np.array([2, 4, 5], dtype=dtype), dtype(7)), # (dtype(1), np.array([2, 4, 5], dtype=dtype)), # diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 73adb0d243b3b27e6c6ba669b2fd134a5976a2ec..77f6eee0cf8ddc9b76f150e1038bf66da34c5218 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -65,7 +65,7 @@ class UnaryOpsTest(xla_test.XLATestCase): rtol: relative tolerance for equality test. atol: absolute tolerance for equality test. """ - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pinp = array_ops.placeholder( dtypes.as_dtype(inp.dtype), inp.shape, name="a") @@ -84,7 +84,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self.assertAllClose(result[i], expected[i], rtol, atol) def testAllTypeOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), np.array( @@ -158,9 +158,6 @@ class UnaryOpsTest(xla_test.XLATestCase): def testFloatOps(self): for dtype in self.float_types: - # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018. - if dtype == np.float16 and self.device == "XLA_CPU": - continue x = np.arange(-0.90, 0.90, 0.25) self._assertOpOutputMatchesExpected( math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype)) @@ -202,7 +199,7 @@ class UnaryOpsTest(xla_test.XLATestCase): # Disable float16 testing for now if dtype != np.float16: x = np.arange(-10, 10, 1).astype(dtype) - with self.test_session() as session: + with self.cached_session() as session: erf_x = session.run(math_ops.erf(x)) erfc_x = session.run(math_ops.erfc(x)) @@ -396,6 +393,11 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [[True, False, True], [False, True, True]], dtype=np.bool)) + self._assertOpOutputMatchesExpected( + math_ops.lgamma, + np.array(0.5, dtype=dtype), + expected=np.array(np.log(np.pi) / 2, dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.lgamma, np.array( @@ -420,6 +422,19 @@ class UnaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) + # The actual result is complex. Take the real part. + self._assertOpOutputMatchesExpected( + math_ops.lgamma, + np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype), + expected=np.array( + [ + np.log(np.pi) / 2 + np.log(2), + np.log(np.pi) / 2 - np.log(15) + np.log(8), + np.log(np.pi) / 2 - np.log(945) + np.log(32), + ], + dtype=dtype), + atol=1e-4) + self._assertOpOutputMatchesExpected( math_ops.digamma, np.array( @@ -615,7 +630,7 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) def testNumericOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index b637cf31cfc303ebe84ce8307ef4ad8b0b5cd720..4ee144beb7f3243be069d59ee4a613484fe183b3 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -43,7 +43,7 @@ class WhileTest(xla_test.XLATestCase): def loop_cond(step): return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) with self.test_scope(): loop_outputs = xla.while_loop([init_index], loop_cond, loop_body) @@ -65,7 +65,7 @@ class WhileTest(xla_test.XLATestCase): del rsum return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) init_sum = array_ops.placeholder(dtypes.float32, []) with self.test_scope(): @@ -91,7 +91,7 @@ class WhileTest(xla_test.XLATestCase): del rsum return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) init_sum = array_ops.placeholder(dtypes.complex64, []) with self.test_scope(): @@ -117,7 +117,7 @@ class WhileTest(xla_test.XLATestCase): del x return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) with self.test_scope(): loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body) diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 06d977b93c28792704b910c688af510bc650d2a4..28d61fb07dcb665fa0dbe3f3e566e291e24fa662 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_control_flow_ops @@ -35,7 +37,7 @@ class XlaDeviceTest(xla_test.XLATestCase): [16384, 1], [1, 16384], [1, 20000, 1, 1]] for dtype in self.numeric_types: for shape in shapes: - with self.test_session() as sess: + with self.cached_session() as sess: with ops.device("CPU"): x = array_ops.placeholder(dtype, shape) with self.test_scope(): @@ -47,8 +49,36 @@ class XlaDeviceTest(xla_test.XLATestCase): result = sess.run(z, {x: inputs}) self.assertAllCloseAccordingToType(result, inputs + inputs) + def testCopiesOfUnsupportedTypesFailGracefully(self): + """Tests that copies of unsupported types don't crash.""" + test_types = set([ + np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, + np.int64, np.float16, np.float32, np.float16, + dtypes.bfloat16.as_numpy_dtype + ]) + shape = (10, 10) + for unsupported_dtype in test_types - self.all_types: + with self.cached_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(unsupported_dtype, shape) + with self.test_scope(): + y, = array_ops.identity_n([x]) + with ops.device("CPU"): + z = array_ops.identity(y) + + inputs = np.random.randint(-100, 100, shape) + inputs = inputs.astype(unsupported_dtype) + # Execution should either succeed or raise an InvalidArgumentError, + # but not crash. Even "unsupported types" may succeed here since some + # backends (e.g., the CPU backend) are happy to handle buffers of + # unsupported types, even if they cannot compute with them. + try: + sess.run(z, {x: inputs}) + except errors.InvalidArgumentError: + pass + def testControlTrigger(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = gen_control_flow_ops.control_trigger() sess.run(x) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf88fc523735cc2d22e085afb83790c7ebb48e4 --- /dev/null +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -0,0 +1,340 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA op wrappers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected, + equality_fn=None): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + if not equality_fn: + equality_fn = self.assertAllClose + equality_fn(result, expected, rtol=1e-3) + + def testAdd(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.add, + args=(np.array([1, 2, 3], dtype=dtype), + np.array([4, 5, 6], dtype=dtype)), + expected=np.array([5, 7, 9], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(0,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 9], [14, 15]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(1,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 13], [10, 15]], dtype=dtype)) + + def testBroadcast(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.broadcast(x, (7, 42)), + args=(v,), + expected=np.tile(v, (7, 42, 1, 1))) + + def testShiftRightLogical(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32)) + + def testShiftRightArithmetic(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([-1, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) + + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfig.DEFAULT, + xla_data_pb2.PrecisionConfig.HIGH, + xla_data_pb2.PrecisionConfig.HIGHEST) + + @parameterized.parameters(*PRECISION_VALUES) + def testConv(self, precision): + for dtype in set(self.float_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + def conv_1d_fn(lhs, rhs): + dnums = xla_data_pb2.ConvolutionDimensionNumbers() + num_spatial_dims = 1 + dnums.input_batch_dimension = 0 + dnums.input_feature_dimension = 1 + dnums.output_batch_dimension = 0 + dnums.output_feature_dimension = 1 + dnums.kernel_output_feature_dimension = 0 + dnums.kernel_input_feature_dimension = 1 + dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfig() + precision_config.operand_precision.extend([precision, precision]) + return xla.conv( + lhs, + rhs, + window_strides=(1,), + padding=((2, 1),), + lhs_dilation=(1,), + rhs_dilation=(2,), + dimension_numbers=dnums) + + self._assertOpOutputMatchesExpected( + conv_1d_fn, + args=( + np.array([[[3, 4, 5, 6]]], dtype=dtype), + np.array([[[-2, -3]]], dtype=dtype), + ), + expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype)) + + @parameterized.parameters(*PRECISION_VALUES) + def testDotGeneral(self, precision): + for dtype in self.float_types: + + def dot_fn(lhs, rhs): + dnums = xla_data_pb2.DotDimensionNumbers() + dnums.lhs_contracting_dimensions.append(2) + dnums.rhs_contracting_dimensions.append(1) + dnums.lhs_batch_dimensions.append(0) + dnums.rhs_batch_dimensions.append(0) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfig() + precision_config.operand_precision.extend([precision, precision]) + return xla.dot_general( + lhs, + rhs, + dimension_numbers=dnums, + precision_config=precision_config) + + lhs = np.array( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + ], dtype=dtype) + rhs = np.array( + [ + [[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]], + ], dtype=dtype) + self._assertOpOutputMatchesExpected( + dot_fn, + args=(lhs, rhs), + expected=np.array( + [ + [[9, 12, 15], [19, 26, 33]], + [[95, 106, 117], [129, 144, 159]], + ], + dtype=dtype)) + + def testNeg(self): + for dtype in self.numeric_types - {np.uint8, np.int8}: + self._assertOpOutputMatchesExpected( + xla.neg, + args=(np.array([1, 2, 3], dtype=dtype),), + expected=np.array([-1, -2, -3], dtype=dtype)) + + def testPad(self): + for dtype in self.numeric_types: + + def pad_fn(x): + return xla.pad( + x, + padding_value=7, + padding_low=[2, 1], + padding_high=[1, 2], + padding_interior=[1, 0]) + + self._assertOpOutputMatchesExpected( + pad_fn, + args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),), + expected=np.array( + [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7], + [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], + dtype=dtype)) + + def testReduce(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def sum_reducer(x, y): + return x + y + + def sum_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4])) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([12, 15, 18, 21], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([6, 22, 38], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0, 1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=dtype(66)) + + @function.Defun(dtype, dtype) + def mul_reducer(x, y): + return x * y + + def mul_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + mul_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([0, 45, 120, 231], dtype=dtype)) + + def testSelectAndScatter(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def add_scatter(x, y): + return x + y + + @function.Defun(dtype, dtype) + def ge_select(x, y): + return x >= y + + def test_fn(operand, source): + return xla.select_and_scatter( + operand, + window_dimensions=[2, 3, 1, 1], + window_strides=[2, 2, 1, 1], + padding=[[0, 0]] * 4, + source=source, + init_value=0, + select=ge_select, + scatter=add_scatter) + + self._assertOpOutputMatchesExpected( + test_fn, + args=(np.array( + [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6], + [0, 6, 2, 10, 2]], + dtype=dtype).reshape((4, 5, 1, 1)), + np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))), + expected=np.array( + [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0], + [0, 0, 0, 1, 0]], + dtype=dtype).reshape((4, 5, 1, 1))) + + def testTranspose(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + + def testDynamicSlice(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.dynamic_slice, + args=(np.arange(1000, + dtype=np.int32).astype(dtype).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3, 2])), + expected=np.array( + np.array([[[573, 574], [583, 584], [593, 594]], + [[673, 674], [683, 684], [693, 694]]]), + dtype=dtype)) + + def testDynamicSliceWithIncorrectStartIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7]), np.array([2, 3, 4])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^start_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and start_indices has shape \[2\].*')) + + def testDynamicSliceWithIncorrectSizeIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^size_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and size_indices has shape \[2\].*')) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 88827cb53bee7bb809d0163d6badcef17e59aa78..98a41981cf30917bc2054c19af5d8176bdfc9862 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -97,10 +97,23 @@ class XLATestCase(test.TestCase): ]) self._numeric_tf_types = set( self.int_tf_types | self._float_tf_types | self.complex_tf_types) - - self._all_types = set( - [dtype.as_numpy_dtype for dtype in self._all_tf_types]) + self.quantized_tf_types = set( + dtype for dtype in self._all_tf_types if dtype.is_quantized) + + # Quantized types don't have a numpy equivalent, include them in + # all_tf_types but not in all_types. + # TODO(b/115960798): Parametrize tests on TF types instead of numpy types + # and remove all_types. + self._all_types = set(dtype.as_numpy_dtype + for dtype in self._all_tf_types + if not dtype.is_quantized) self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self.signed_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if not dtype.is_unsigned) + self.unsigned_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if dtype.is_unsigned) self._float_types = set( [dtype.as_numpy_dtype for dtype in self._float_tf_types]) self.complex_types = set([ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 61759fd2764205fab7fce11c4003e84be1be813a..ba1e3b2b4fdbb73e98105ace6571783ef780adf5 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -39,6 +39,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -75,6 +76,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", ":xla_compiler", @@ -88,6 +90,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -95,6 +98,10 @@ cc_library( name = "cpu_function_runtime", srcs = ["cpu_function_runtime.cc"], hdrs = ["cpu_function_runtime.h"], + visibility = [ + "//tensorflow/compiler/aot:__pkg__", + "//tensorflow/compiler/xla/service/cpu:__pkg__", + ], deps = [ # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. @@ -144,6 +151,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service/cpu:buffer_info_util", "//tensorflow/compiler/xla/service/cpu:cpu_executable", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -181,9 +189,9 @@ cc_library( deps = [ ":common", ":dump_graph", - ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", + ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", @@ -206,6 +214,8 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) @@ -215,13 +225,11 @@ cc_library( srcs = [ "literal_util.cc", "shape_util.cc", - "str_util.cc", "type_util.cc", ], hdrs = [ "literal_util.h", "shape_util.h", - "str_util.h", "type_util.h", ], visibility = [":friends"], @@ -233,6 +241,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", ], ) @@ -250,6 +259,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -274,6 +284,7 @@ cc_library( deps = [ ":sharding_util", ":tf2xla_proto", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -282,6 +293,8 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -300,6 +313,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -347,6 +361,7 @@ tf_cc_test( name = "xla_compiler_test", srcs = ["xla_compiler_test.cc"], deps = [ + ":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", @@ -358,6 +373,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", @@ -367,19 +383,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - ], -) - -tf_cc_test( - name = "str_util_test", - srcs = [ - "str_util_test.cc", - ], - deps = [ - ":common", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -434,25 +438,116 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "functionalize_control_flow_util", + srcs = [ + "functionalize_control_flow_util.cc", + ], + hdrs = [ + "functionalize_control_flow_util.h", + ], + deps = [ + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "functionalize_cond", + srcs = [ + "functionalize_cond.cc", + ], + hdrs = [ + "functionalize_cond.h", + ], + deps = [ + ":functionalize_control_flow_util", + ":tf2xla_util", + "//tensorflow/compiler/jit:union_find", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) cc_library( name = "functionalize_control_flow", - srcs = ["functionalize_control_flow.cc"], - hdrs = ["functionalize_control_flow.h"], + srcs = [ + "functionalize_control_flow.cc", + ], + hdrs = [ + "functionalize_control_flow.h", + ], + deps = [ + ":functionalize_cond", + ":functionalize_control_flow_util", + ":functionalize_while", + ":tf2xla_util", + "//tensorflow/compiler/jit:union_find", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "functionalize_control_flow_pass_registration", + srcs = [ + "functionalize_control_flow_pass_registration.cc", + ], + deps = [ + ":functionalize_control_flow", + ], + alwayslink = 1, +) + +cc_library( + name = "functionalize_while", + srcs = [ + "functionalize_while.cc", + ], + hdrs = [ + "functionalize_while.h", + ], deps = [ + ":functionalize_cond", + ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", ], ) @@ -460,6 +555,32 @@ tf_cc_test( name = "functionalize_control_flow_test", srcs = ["functionalize_control_flow_test.cc"], deps = [ + ":functionalize_control_flow", + ":test_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/compiler/tf2xla/cc:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:ops", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "functionalize_cond_test", + srcs = ["functionalize_cond_test.cc"], + deps = [ + ":functionalize_cond", ":functionalize_control_flow", ":test_util", "//tensorflow/cc:cc_ops", @@ -477,6 +598,7 @@ tf_cc_test( "//tensorflow/core:resource_variable_ops_op_lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -489,6 +611,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -503,3 +626,38 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "resource_operation_table", + srcs = ["resource_operation_table.cc"], + hdrs = ["resource_operation_table.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "resource_operation_table_test", + srcs = ["resource_operation_table_test.cc"], + deps = [ + ":resource_operation_table", + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "side_effect_util", + srcs = ["side_effect_util.cc"], + hdrs = ["side_effect_util.h"], + deps = [ + "//tensorflow/core:core_cpu", + ], +) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index de1008803d69fefa415c7bdbe6c27a62e625b417..027ca6d2d2f616177d91d9d57d1ff373bab2a754 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -23,41 +23,44 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" namespace tensorflow { - // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, - std::vector* compile_time_const_args) { - // Operators that don't look at the data of their inputs, just the shapes. - const std::unordered_set metadata_ops = { - "Rank", - "Shape", - "ShapeN", - "Size", - }; + std::vector* compile_time_const_arg_indices, + std::vector* compile_time_const_nodes, + std::function edge_filter) { + std::vector compile_time_const_nodes_impl; + if (compile_time_const_nodes) { + CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids()); + } else { + compile_time_const_nodes_impl.resize(g.num_node_ids()); + compile_time_const_nodes = &compile_time_const_nodes_impl; + } Status status; - std::unordered_set must_be_const; - auto visit = [&status, &metadata_ops, &must_be_const, - compile_time_const_args](Node* node) { + auto visit = [&](Node* node) { if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. - if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return; + if (XlaOpRegistry::IsMetadataOp(node->type_string())) { + return; + } // If this node must be const, and it isn't a metadata op, then all of its // parents must be const. - if (must_be_const.find(node) != must_be_const.end()) { + if ((*compile_time_const_nodes)[node->id()]) { if (node->type_string() == "_Arg") { int index; status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; - compile_time_const_args->at(index) = true; + if (compile_time_const_arg_indices) { + (*compile_time_const_arg_indices)[index] = true; + } return; } for (const Edge* pred : node->in_edges()) { - if (!pred->IsControlEdge()) { - must_be_const.insert(pred->src()); + if (!pred->IsControlEdge() && edge_filter(*pred)) { + (*compile_time_const_nodes)[pred->src()->id()] = true; } } return; @@ -79,8 +82,9 @@ Status BackwardsConstAnalysis(const Graph& g, for (Edge const* edge : node->in_edges()) { if (edge->dst_input() >= name_range->second.first && - edge->dst_input() < name_range->second.second) { - must_be_const.insert(edge->src()); + edge->dst_input() < name_range->second.second && + edge_filter(*edge)) { + (*compile_time_const_nodes)[edge->src()->id()] = true; } } } @@ -88,7 +92,8 @@ Status BackwardsConstAnalysis(const Graph& g, // Post-order traversal visits nodes in reverse topological order for an // acyclic graph. - DFS(g, {}, visit); + DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{}, + [](const Edge& edge) { return !edge.src()->IsNextIteration(); }); return status; } diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index 634b97d7e3760c0344c948a56353ade243284aa6..49b3c6d413c6b637fa825bf182be7cc36e49b6c8 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -23,10 +23,22 @@ limitations under the License. namespace tensorflow { -// Backwards dataflow analysis that finds arguments (_Arg nodes) to a graph that -// must be compile-time constants. -Status BackwardsConstAnalysis(const Graph& graph, - std::vector* compile_time_const_args); +// Backwards dataflow analysis that finds nodes in a graph that must be +// compile-time constants for us to be able to lower the graph to XLA. +// +// The indices of the arguments to `graph` that must be constant are returned in +// `compile_time_const_arg_indices`, if `compile_time_const_arg_indices` is not +// null. +// +// The ids of the nodes in `graph` that must be constant are returned in +// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. +// +// Only propagate const-ness along edges for which `edge_filter` returns true. +Status BackwardsConstAnalysis(const Graph& g, + std::vector* compile_time_const_arg_indices, + std::vector* compile_time_const_nodes, + std::function edge_filter = + [](const Edge& e) { return true; }); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 992b12c06db5efc0ae54284d0ea77017c1c79aca..56065be894697bc72ecc0089c665c19aafee7bf8 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -38,17 +39,23 @@ TEST(ConstAnalysisTest, Basics) { auto c = ops::Reshape(root, arg2, b); auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3)); - Graph graph(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(&graph)); + FixupSourceAndSinkEdges(root.graph()); std::vector const_args(4, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + std::vector const_nodes(root.graph()->num_node_ids(), false); + TF_ASSERT_OK( + BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes)); // Arg 0 doesn't need to be constant since the graph only uses its shape. // Arg 1 must be constant because it flows to the shape argument of a Reshape. // Arg 2 is used only as the value input to a Reshape and need not be const. // Arg 3 is used as the reduction-indices argument to Sum and must be const. EXPECT_EQ(const_args, std::vector({false, true, false, true})); + + EXPECT_FALSE(const_nodes[arg0.node()->id()]); + EXPECT_TRUE(const_nodes[arg1.node()->id()]); + EXPECT_FALSE(const_nodes[arg2.node()->id()]); + EXPECT_TRUE(const_nodes[arg3.node()->id()]); } // Regression test for a case where the backward const analysis did @@ -73,7 +80,8 @@ TEST(ConstAnalysisTest, TopologicalOrder) { TF_ASSERT_OK(root.ToGraph(&graph)); std::vector const_args(3, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); EXPECT_EQ(const_args, std::vector({true, true, false})); } @@ -93,7 +101,8 @@ TEST(ConstAnalysisTest, DontFollowControlDependencies) { TF_ASSERT_OK(root.ToGraph(&graph)); std::vector const_args(2, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); EXPECT_EQ(const_args, std::vector({false, true})); } diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc index 2ffad2af8cfe621f0cbbdd8a9484ef2dfdf1b129..fcc4095e39673b786544984a41988c3e9c5b0efb 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc @@ -55,19 +55,26 @@ size_t align_to(size_t n, size_t align) { } // namespace namespace cpu_function_runtime { -size_t AlignedBufferBytes(const intptr_t* sizes, size_t n) { +size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params) { size_t total = 0; for (size_t i = 0; i < n; ++i) { - if (sizes[i] > 0) { - total += align_to(sizes[i], kAlign); + bool should_allocate = + buffer_infos[i].is_temp_buffer() || + (buffer_infos[i].is_entry_parameter() && allocate_entry_params); + + if (should_allocate) { + total += align_to(buffer_infos[i].size(), kAlign); } } return total; } -void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, +void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params, void** bufs, bool annotate_initialized) { - const size_t total = AlignedBufferBytes(sizes, n); + const size_t total = + AlignedBufferBytes(buffer_infos, n, allocate_entry_params); void* contiguous = nullptr; if (total > 0) { contiguous = aligned_malloc(total, kAlign); @@ -79,13 +86,14 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, } uintptr_t pos = reinterpret_cast(contiguous); for (size_t i = 0; i < n; ++i) { - if (sizes[i] < 0) { - // bufs[i] is either a constant, an entry parameter or a thread local - // allocation. - bufs[i] = nullptr; - } else { + bool should_allocate = + buffer_infos[i].is_temp_buffer() || + (buffer_infos[i].is_entry_parameter() && allocate_entry_params); + if (should_allocate) { bufs[i] = reinterpret_cast(pos); - pos += align_to(sizes[i], kAlign); + pos += align_to(buffer_infos[i].size(), kAlign); + } else { + bufs[i] = nullptr; } } return contiguous; diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h index c7b4559c65731d1c4f4ea41e8be173ba89fe359c..dfc1e8b8aebcf3142e9f61f60171c6b58634c71d 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime.h +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h @@ -18,29 +18,142 @@ limitations under the License. #include "tensorflow/core/platform/types.h" +#include + namespace tensorflow { namespace cpu_function_runtime { +// Stores information about one buffer used by an XLA:CPU compiled function. +// These buffers are used for holding inputs to the computation, outputs from +// the computation and as temporary scratch space. +class BufferInfo { + public: + // Creates a BufferInfo from a serialized encoding generated by `Encode`. + explicit BufferInfo(std::pair encoding) + : entry_param_number_(encoding.second) { + Kind kind; + uint64 size; + Unpack(encoding.first, &kind, &size); + kind_ = kind; + size_ = size; + } + + // Returns true if this buffer stores a constant. These never need to be + // allocated by the runtime. + bool is_constant() const { return kind() == Kind::kConstant; } + + // Returns true if this buffer stores an entry parameter. These may or may + // not need to be allocated by the runtime, depending on + // XlaCompiledCpuFunction::AllocMode. + bool is_entry_parameter() const { return kind() == Kind::kEntryParameter; } + + // Returns the entry parameter number of this buffer. + uint64 entry_parameter_number() const { + assert(is_entry_parameter()); + return entry_param_number_; + } + + // Returns true if this buffer is temporary scratch space required by the XLA + // computations. These are always allocated by the runtime. + bool is_temp_buffer() const { return kind() == Kind::kTempBuffer; } + + // Returns true if this buffer is allocated on the C stack or into registers. + // These buffers are never allocated by the runtime. + bool is_on_stack_buffer() const { return kind() == Kind::kOnStackBuffer; } + + // Returns the size for this buffer. + uint64 size() const { return size_; } + + // Encodes this BufferInfo into two 64 bit integers that can be used to + // reconstruct the BufferInfo later using the constructor. We need this + // because we use BufferInfo in places where using protocol buffers would + // negatively impact binary size. + std::pair Encode() const { + static_assert(sizeof(*this) == 16, ""); + uint64 upper = Pack(kind(), size_); + uint64 lower = entry_param_number_; + return {upper, lower}; + } + + bool operator==(const BufferInfo& buffer_info) const { + if (kind() != buffer_info.kind() || size() != buffer_info.size()) { + return false; + } + return !is_entry_parameter() || + entry_parameter_number() == buffer_info.entry_parameter_number(); + } + + // Factory methods: + + static BufferInfo MakeTempBuffer(uint64 size) { + return BufferInfo(Kind::kTempBuffer, /*size=*/size, + /*entry_param_number=*/-1); + } + static BufferInfo MakeConstant(uint64 size) { + return BufferInfo(Kind::kConstant, /*size=*/size, + /*entry_param_number=*/-1); + } + static BufferInfo MakeEntryParameter(uint64 size, uint64 param_number) { + return BufferInfo(Kind::kEntryParameter, /*size=*/size, + /*entry_param_number=*/param_number); + } + static BufferInfo MakeOnStackBuffer(uint64 size) { + return BufferInfo(Kind::kOnStackBuffer, /*size=*/size, + /*entry_param_number=*/-1); + } + + private: + BufferInfo() = default; + + enum class Kind : unsigned { + kConstant, + kTempBuffer, + kEntryParameter, + kOnStackBuffer + }; + + Kind kind() const { return static_cast(kind_); } + + explicit BufferInfo(Kind kind, uint64 size, uint64 entry_param_number) + : kind_(kind), size_(size), entry_param_number_(entry_param_number) {} + + static uint64 Pack(Kind kind, uint64 size) { + return (static_cast(size) << 2) | static_cast(kind); + } + + static void Unpack(uint64 packed, Kind* kind, uint64* size) { + *size = packed >> 2; + *kind = static_cast((packed << 62) >> 62); + } + + Kind kind_ : 2; + uint64 size_ : 62; + int64 entry_param_number_; +}; // Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. constexpr size_t kAlign = 64; -// AlignedBufferBytes returns the sum of each size in `sizes`, skipping -1 -// values. There are `n` entries in `sizes`. Each buffer is aligned to -// kAlign byte boundaries. -size_t AlignedBufferBytes(const intptr_t* sizes, size_t n); +// AlignedBufferBytes returns the sum of the size of each buffer in +// `buffer_infos`, skipping constants, on-stack buffers and, if +// allocate_entry_params is false, entry parameters. There are `n` entries in +// `buffer_infos`. Each buffer is aligned to kAlign byte boundaries. +size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params); // MallocContiguousBuffers allocates buffers for use by the entry point -// generated by tfcompile. `sizes` is an array of byte sizes for each buffer, -// where -1 causes the buffer pointer to be nullptr. There are `n` entries in -// `sizes`. If `annotate_initialized` is set, the allocated memory will be -// annotated as having been initialized - this is useful when allocating -// temporary buffers. +// generated by tfcompile. There are `n` entries in `buffer_infos`. If +// `annotate_initialized` is set, the allocated memory will be annotated as +// having been initialized - this is useful when allocating temporary buffers. +// If allocate_entry_params is true then allocates temp buffers and entry +// parameters, otherwise allocated only temp buffers. Slots in `bufs` +// corresponding to unallocated buffers are set to nullptr. // // A single contiguous block of memory is allocated, and portions of it are // parceled out into `bufs`, which must have space for `n` entries. Returns // the head of the allocated contiguous block, which should be passed to // FreeContiguous when the buffers are no longer in use. -void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, +void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params, void** bufs, bool annotate_initialized); // FreeContiguous frees the contiguous block of memory allocated by diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc index f4f27a156261ea6872777cef76ecaf7dd7eebe0d..8ca628c4eb6700d7184899bc1753dd6c6aa392b0 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc @@ -21,6 +21,8 @@ limitations under the License. namespace tensorflow { namespace { +using cpu_function_runtime::BufferInfo; + TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. @@ -30,20 +32,51 @@ TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment); } +std::vector SizesToBufferInfos(const intptr_t* sizes, size_t n) { + std::vector buffer_infos; + std::transform(sizes, sizes + n, std::back_inserter(buffer_infos), + [&](intptr_t size) { + if (size == -1) { + // Use a dummy on-stack buffer allocation to indicat the + // the current slot does not need an allocation. + int64 on_stack_buffer_size = 4; + return BufferInfo::MakeOnStackBuffer(on_stack_buffer_size); + } + return BufferInfo::MakeTempBuffer(size); + }); + return buffer_infos; +} + +// Simple wrappers to make writing tests more ergonomic. + +size_t AlignedBufferBytesFromSizes(const intptr_t* sizes, size_t n) { + std::vector buffer_infos = SizesToBufferInfos(sizes, n); + return AlignedBufferBytes(buffer_infos.data(), n, + /*allocate_entry_params=*/false); +} + +void* MallocContiguousBuffersFromSizes(const intptr_t* sizes, size_t n, + void** bufs, bool annotate_initialized) { + std::vector buffer_infos = SizesToBufferInfos(sizes, n); + return MallocContiguousBuffers(buffer_infos.data(), n, + /*allocate_entry_params=*/false, bufs, + annotate_initialized); +} + TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) { - EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(nullptr, 0), 0); + EXPECT_EQ(AlignedBufferBytesFromSizes(nullptr, 0), 0); static constexpr intptr_t sizesA[1] = {-1}; - EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesA, 1), 0); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesB, 1), 64); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesC, 1), 64); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesD, 7), 320); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -56,15 +89,14 @@ void* add_ptr(void* base, uintptr_t delta) { // free. We also check the contiguous property. TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { // Test empty sizes. - void* base = - cpu_function_runtime::MallocContiguousBuffers(nullptr, 0, nullptr, false); + void* base = MallocContiguousBuffersFromSizes(nullptr, 0, nullptr, false); EXPECT_EQ(base, nullptr); cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with 0 sum. static constexpr intptr_t sizesA[1] = {-1}; void* bufA[1]; - base = cpu_function_runtime::MallocContiguousBuffers(sizesA, 1, bufA, false); + base = MallocContiguousBuffersFromSizes(sizesA, 1, bufA, false); EXPECT_EQ(base, nullptr); EXPECT_EQ(bufA[0], nullptr); cpu_function_runtime::FreeContiguous(base); @@ -72,7 +104,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { // Test non-empty sizes with non-0 sum. static constexpr intptr_t sizesB[1] = {3}; void* bufB[1]; - base = cpu_function_runtime::MallocContiguousBuffers(sizesB, 1, bufB, false); + base = MallocContiguousBuffersFromSizes(sizesB, 1, bufB, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufB[0], add_ptr(base, 0)); char* bufB0_bytes = static_cast(bufB[0]); @@ -84,7 +116,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { // Test non-empty sizes with non-0 sum, and annotate_initialized. static constexpr intptr_t sizesC[1] = {3}; void* bufC[1]; - base = cpu_function_runtime::MallocContiguousBuffers(sizesC, 1, bufC, true); + base = MallocContiguousBuffersFromSizes(sizesC, 1, bufC, true); EXPECT_NE(base, nullptr); EXPECT_EQ(bufC[0], add_ptr(base, 0)); char* bufC0_bytes = static_cast(bufC[0]); @@ -96,7 +128,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { // Test mixed sizes. static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; void* bufD[7]; - base = cpu_function_runtime::MallocContiguousBuffers(sizesD, 7, bufD, false); + base = MallocContiguousBuffersFromSizes(sizesD, 7, bufD, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); @@ -117,5 +149,23 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { cpu_function_runtime::FreeContiguous(base); } +void CheckRoundTripIsOk(const BufferInfo& buffer_info) { + BufferInfo round_trip(buffer_info.Encode()); + ASSERT_EQ(round_trip, buffer_info); +} + +TEST(XlaCompiledCpuFunctionTest, BufferInfoTest) { + CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(0)); + CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(4)); + CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(0)); + CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(4)); + CheckRoundTripIsOk(BufferInfo::MakeConstant(0)); + CheckRoundTripIsOk(BufferInfo::MakeConstant(4)); + CheckRoundTripIsOk( + BufferInfo::MakeEntryParameter(/*size=*/0, /*param_number=*/4)); + CheckRoundTripIsOk( + BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/0)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 24616c01c7e54b2e8662457ca6af23a0bc563e08..380c6a7e23da92d949b26876836b999bf6406c6c 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" @@ -52,9 +52,9 @@ string MakeUniqueFilename(string name) { string filename = name; if (count > 0) { - strings::StrAppend(&filename, "_", count); + absl::StrAppend(&filename, "_", count); } - strings::StrAppend(&filename, ".pbtxt"); + absl::StrAppend(&filename, ".pbtxt"); return filename; } @@ -69,7 +69,7 @@ string WriteTextProtoToUniqueFile( << proto_type << ": " << status; return "(unavailable)"; } - string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name)); + string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); status = WriteTextProto(Env::Default(), filepath, proto); if (!status.ok()) { LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc new file mode 100644 index 0000000000000000000000000000000000000000..db256e577a1f3dd38e04d102f60182023b9d43b2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -0,0 +1,1384 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" + +using xla::StatusOr; + +namespace tensorflow { +namespace functionalize_cond { + +// TODO(jpienaar): Move to OutputTensor. +string DebugString(const OutputTensor& tensor) { + return absl::StrCat(tensor.node->name(), ":", tensor.index); +} + +string Branch_Name(BranchType b) { + switch (b) { + case BranchType::kElseBranch: + return "else"; + case BranchType::kThenBranch: + return "then"; + case BranchType::kBoth: + return "both"; + case BranchType::kNeither: + return "neither"; + } +} + +string DebugString(StateMap::CondId cond_state) { + if (cond_state == nullptr || cond_state->empty()) return "{}"; + using value_type = StateMap::CondState::value_type; + return absl::StrCat( + "{", + absl::StrJoin(*cond_state, ", ", + [](string* output, const value_type& pred_branch) { + const OutputTensor& pred = pred_branch.first; + const BranchType& branch = pred_branch.second; + if (branch == BranchType::kNeither) + absl::StrAppend(output, "d"); + else + absl::StrAppend(output, "s(", DebugString(pred), ",", + Branch_Name(branch), ")"); + }), + "}"); +} + +// Returns the predicate of a switch. +Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { + const Edge* pred_edge; + TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge)); + // The predicate can be preceded by a identity node. Look through + // identity nodes to predicate. + while (pred_edge->src()->IsIdentity()) { + TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge)); + } + *pred = OutputTensor(pred_edge->src(), pred_edge->src_output()); + return Status::OK(); +} + +Status GetSwitchValue(const Node& switch_node, OutputTensor* val) { + const Edge* val_edge; + TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge)); + *val = OutputTensor(val_edge->src(), val_edge->src_output()); + return Status::OK(); +} + +bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs, + const OutputTensor& rhs) const { + return (lhs.node->id() < rhs.node->id()) || + (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index); +} + +struct CondStateLess { + bool operator()(const StateMap::CondState::value_type& lhs, + const StateMap::CondState::value_type& rhs) const { + if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first)) + return true; + if (lhs.first.node->id() == rhs.first.node->id() && + lhs.first.index == rhs.first.index) + return lhs.second < rhs.second; + return false; + } +}; + +StateMap::StateMap(Graph* graph) { + node_to_condid_map_.resize(graph->num_node_ids()); + node_to_ancestorid_map_.resize(graph->num_node_ids()); + // Initialize the dead state (empty state is designated with a nullptr). + dead_id_ = GetCondId( + {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)}); +} + +bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; } + +bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; } + +size_t StateMap::Hash::operator()(const StateMap::CondState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = Hash64Combine(OutputTensor::Hash()(it->first), + hash()(it->second)); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. + h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first), + hash()(it->second))); + } + return h; +} + +size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = hash()(*it); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. + h = Hash64Combine(h, hash()(*it)); + } + return h; +} + +// CondArgNode represents a input to the conditional and its corresponding +// switch nodes. +struct CondArgNode { + explicit CondArgNode(Node* src, int src_output) + : src(src), src_output(src_output) {} + + string ToString() const { + return absl::StrCat("src=", src->name(), ":", src_output, + " switches=", NodesToString(switches)); + } + + Node* src; + int src_output; + std::array branch_copy; + std::vector switches; +}; +using CondArgNodes = std::vector; + +string DebugString(const CondArgNodes& nodes) { + return absl::StrCat( + "[", + absl::StrJoin(nodes, ", ", + [](string* output, const CondArgNode& node) { + absl::StrAppend(output, node.ToString()); + }), + "]"); +} + +StateMap::CondId StateMap::LookupCondId(const Node* node) const { + if (node->id() < node_to_condid_map_.size()) + return node_to_condid_map_[node->id()]; + return added_node_condid_mapping_.at(node->id()); +} + +StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) { + if (state.empty()) return nullptr; + return &*condstate_set_.insert(state).first; +} + +void StateMap::ResetCondId(const Node* node, StateMap::CondId id) { + if (node->id() < node_to_condid_map_.size()) + node_to_condid_map_[node->id()] = id; + else + added_node_condid_mapping_[node->id()] = id; +} + +StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const { + if (node->id() < node_to_ancestorid_map_.size()) + return node_to_ancestorid_map_[node->id()]; + return added_node_ancestorid_mapping_.at(node->id()); +} + +StateMap::AncestorId StateMap::GetAncestorId( + const StateMap::AncestorState& state) { + if (state.empty()) return nullptr; + return &*ancestorstate_set_.insert(state).first; +} + +void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { + if (node->id() < node_to_ancestorid_map_.size()) + node_to_ancestorid_map_[node->id()] = id; + else + added_node_ancestorid_mapping_[node->id()] = id; +} + +void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); } + +string StateMap::CondStateToString(const Node* node) const { + return CondStateToString(LookupCondId(node)); +} + +string StateMap::CondStateToString(StateMap::CondId id) const { + return DebugString(id); +} + +string StateMap::AncestorStateToString(const Node* node) const { + if (auto id = LookupAncestorId(node)) return NodesToString(*id); + return "{}"; +} + +FunctionalizeCond::FunctionalizeCond(Graph* graph, + FunctionLibraryDefinition* library) + : state_map_(graph), library_(library), graph_(graph) {} + +// Class representing the merge/switch nodes that will become a conditional. +class Conditional { + public: + Conditional(OutputTensor predicate, FunctionalizeCond* parent, + StateMap* cond_state_map); + + // Adds merge node that is part of this conditional. + Status AddMerge(Node* m); + + // Constructs an If node from the merge nodes. + Status BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library); + + private: + // Extracts the then/else bodies: creates new graphs with the nodes + // corresponding to the nodes in the then/else branches as of this conditional + // as function bodies. + Status ExtractBodies(Graph* graph); + + // Builds the arguments that are the input to the If. + Status BuildArgumentNodes(); + + // Builds the If node for the extracted bodies with the given predicate. + Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library); + + // Adds input edges to If node. + Status AddInputEdges(Graph* graph); + + // Adds output edges from If node. + Status AddOutputEdges(Graph* graph); + + // Adds switch node that is part of this conditional. + Status AddSwitch(Node* s); + + // Adds a switch node along the edge and rewire the edge to go via the switch. + Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph); + + // Internal name of conditional. The name is based on the first merge node + // added. + string name() const; + + // The FunctionalizeCond instance that created this. + FunctionalizeCond* parent_; + + // Mapping between nodes and their cond state. + StateMap* state_map_; + + // The predicate of the conditional. + OutputTensor predicate_; + + // The predicate of the switches of the conditional. This may be different + // than predicate (which is initialized from the original graph) as the + // predicate could be the output of a newly created If node. + OutputTensor switch_predicate_; + + // Switch nodes in graph that are part of this conditional. + std::set switches_; + + // Merge nodes in graph that are part of this conditional. + std::set merges_; + + // Vector of control inputs from outside the conditional to a node inside. + std::vector external_control_inputs_; + std::vector external_control_outputs_; + + // Graphs corresponding to the then and else branch. + std::array, 2> bodies_; + + // Maps from graph_ to the branch body's graph. + std::array, 2> node_maps_; + + // The argument nodes created for the switches. + CondArgNodes cond_arg_nodes_; + + // The constructed If node. + Node* if_node_ = nullptr; + + // Whether the merge nodes of this conditional have been replaced. + bool replaced_ = false; +}; + +Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, + StateMap* cond_state_map) + : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {} + +Status Conditional::AddMerge(Node* m) { + merges_.insert(m); + return Status::OK(); +} + +Status Conditional::AddSwitch(Node* s) { + VLOG(5) << "Adding switch " << s->DebugString(); + OutputTensor predicate; + TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate)); + if (switch_predicate_.node == nullptr) switch_predicate_ = predicate; + if (!(switch_predicate_ == predicate)) { + return errors::InvalidArgument( + "Merge nodes ", NodesToString(merges_), + " directly dominated by switch nodes with different predicates (", + DebugString(switch_predicate_), " vs ", DebugString(predicate), ")."); + } + switches_.insert(s); + return Status::OK(); +} + +Status Conditional::BuildArgumentNodes() { + VLOG(1) << "Build function arguments"; + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second)); + } + }; + + std::unordered_map, int, Hash> input_index; + for (Node* switch_node : switches_) { + const Edge* e; + TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e)); + std::pair key = std::make_pair(e->src(), e->src_output()); + if (input_index.find(key) == input_index.end()) { + input_index[key] = cond_arg_nodes_.size(); + cond_arg_nodes_.emplace_back(key.first, key.second); + } + cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node); + } + VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_); + + int arg_count = 0; + for (CondArgNode& cond_arg_node : cond_arg_nodes_) { + DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output); + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + TF_RETURN_IF_ERROR( + NodeBuilder(absl::StrCat("_Arg", arg_count), + FunctionLibraryDefinition::kArgOp) + .Attr("T", dtype) + .Attr("index", arg_count) + .Finalize(bodies_[branch_index].get(), + &cond_arg_node.branch_copy[branch_index])); + } + for (Node* node : cond_arg_node.switches) { + for (const Edge* e : node->out_edges()) { + if (e->IsControlEdge()) continue; + int branch_index = e->src_output(); + Node* src_copy = cond_arg_node.branch_copy[branch_index]; + Node* dst_copy = node_maps_[branch_index][e->dst()->id()]; + + // The graph may contain dead switch nodes, + if (dst_copy == nullptr) continue; + + TF_RET_CHECK(dst_copy != nullptr) + << "Unable to find copied node for " << e->dst()->DebugString() + << " on branch " << Branch_Name(BranchType(branch_index)); + // If the input goes directly to a merge then the merge has + // been replaced by a retval so the dst input is 0 instead of + // dst_input. + int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input(); + bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input); + } + } + ++arg_count; + } + + // Verify that all retvals have an input. + // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have + // input. + for (Node* m : merges_) { + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + bool has_input = false; + for (auto e : node_maps_[static_cast(branch)][m->id()]->in_edges()) { + if (!e->IsControlEdge()) { + has_input = true; + break; + } + } + if (!has_input) { + return errors::Internal( + "Failed to functionalize control flow with merge ", + FormatNodeForError(*m), " that doesn't have input on ", + Branch_Name(branch), " branch."); + } + } + } + + return Status::OK(); +} + +Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph) { + // Previously we had edge: + // src:src_output ---- edge ----> dst:dst_input + // post this we have (in graph) + // src:src_output --> switch --- new_edge --> dst:dst_input + + // TODO(jpienaar): One could keep a map caching the extra switch nodes added + // to avoid adding another switch to feed a value for which a switch was + // already added. + Node* switch_node; + Node* src = edge->src(); + int src_output = edge->src_output(); + TF_RETURN_IF_ERROR( + NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")), + "Switch") + .Input(src, src_output) + .Input(const_cast(predicate_.node), predicate_.index) + .Finalize(graph, &switch_node)); + state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src)); + state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src)); + + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + graph->AddEdge(switch_node, static_cast(branch), dst, dst_input); + return AddSwitch(switch_node); +} + +Status Conditional::ExtractBodies(Graph* graph) { + VLOG(2) << "Extracting bodies for " << name(); + for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) { + bodies_[static_cast(b)] = + absl::make_unique(graph->op_registry()); + } + + auto find_branch = [&](const Edge* e) { + const auto& id = state_map_->LookupCondId(e->src()); + return IsSwitch(e->src()) ? BranchType(e->src_output()) + : state_map_->FindBranchOf(id, predicate_); + }; + + std::array, 2> stacks; + VLOG(5) << "Merges: " << NodesToString(merges_); + for (Node* m : merges_) { + VLOG(5) << "For merge: " << m->DebugString() << " " + << state_map_->CondStateToString(m); + for (auto e : m->in_edges()) { + if (e->IsControlEdge()) continue; + BranchType branch = find_branch(e); + TF_RET_CHECK(branch == BranchType::kThenBranch || + branch == BranchType::kElseBranch) + << "Error: " << e->src()->name() + << " is not on either then or else branch (" << Branch_Name(branch) + << ") for predicate " << DebugString(predicate_) << " [" + << DebugString(state_map_->LookupCondId(e->src())) << "]."; + Node* src = e->src(); + if (IsSwitch(src)) { + // Switch node outputs and dependencies are handled separately. + TF_RETURN_IF_ERROR(AddSwitch(src)); + } else { + stacks[static_cast(branch)].push_back(src); + } + } + } + + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + auto output = bodies_[branch_index].get(); + auto& stack = stacks[branch_index]; + VLOG(5) << "In branch: " << Branch_Name(branch) << " " + << NodesToString(stack); + std::vector visited(graph->num_node_ids(), false); + node_maps_[branch_index].resize(graph->num_node_ids(), nullptr); + auto& node_map = node_maps_[branch_index]; + + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + if (visited.at(n->id())) continue; + visited[n->id()] = true; + + // Verify output edges and record control edges exitting scope. + for (const Edge* e : n->out_edges()) { + Node* dst = e->dst(); + if (IsMerge(dst)) continue; + Node* src = e->src(); + + auto dst_id = state_map_->LookupCondId(dst); + auto src_id = state_map_->LookupCondId(src); + if (dst_id != src_id) { + if (e->IsControlEdge()) { + external_control_outputs_.push_back(e->src()); + } else { + // Constants are treated specially to workaround the case of + // non-dominated constant nodes. + if (!IsConstant(src)) { + // TODO(b/78882471): A node that feeds into two different + // CondState is not necessarily an error so log a warning for now + // but revisit to improve the testing to enable making this an + // error. + LOG(WARNING) << errors::InvalidArgument( + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during out edge testing)"); + } + } + } + } + + // Copying incomming edges to dst node. Iterate over a copy of the edges + // as they could be mutated during iteration. + std::vector in_edges(n->in_edges().begin(), + n->in_edges().end()); + for (const Edge* e : in_edges) { + Node* src = e->src(); + // Skip src/dst node. + if (!src->IsOp()) continue; + + Node* dst = e->dst(); + if (IsSwitch(src)) { + // Switch node outputs and dependencies are handled separately. + TF_RETURN_IF_ERROR(AddSwitch(src)); + continue; + } + + // Verify input is from the same context. + auto src_id = state_map_->LookupCondId(src); + auto dst_id = state_map_->LookupCondId(dst); + if (IsMerge(dst) || src_id == dst_id) { + // TODO(jpienaar): The merge case can be more strict. + if (node_map.at(src->id()) == nullptr) { + node_map.at(src->id()) = output->CopyNode(src); + stack.push_back(src); + } + } else if (e->IsControlEdge()) { + external_control_inputs_.push_back(src); + } else { + // This shouldn't happen, this means we have an external data input + // not entering via a switch node. Work around this by for + // * constant nodes copy them; + // * non-constant nodes, insert a switch along the edge; + if (IsConstant(src)) { + node_map.at(src->id()) = output->CopyNode(src); + } else { + StateMap::CondState state = *dst_id; + state.erase(predicate_); + if (state_map_->GetCondId(state) == src_id) { + TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph)); + continue; + } else { + return errors::InvalidArgument( + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during in edge testing)"); + } + } + } + + Node* src_copy = node_map.at(e->src()->id()); + int src_output = e->src_output(); + if (node_map.at(dst->id()) == nullptr) { + node_map.at(dst->id()) = output->CopyNode(dst); + } + Node* dst_copy = node_map.at(e->dst()->id()); + if (e->IsControlEdge()) { + // Skip control inputs from external context. + if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy); + } else { + output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); + } + } + } + } + + // Build return values from the merge nodes. + int index = 0; + for (Node* m : merges_) { + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + auto& node_map = node_maps_[branch_index]; + auto output = bodies_[branch_index].get(); + TF_ASSIGN_OR_RETURN(node_map[m->id()], + BuildRetvalNode(output, m->output_type(0), index)); + } + ++index; + + // Connect the input to the merge_ with the retval, except if it is a + // Swich node, which is handled separately. + for (auto e : m->in_edges()) { + if (e->IsControlEdge()) continue; + int branch_index = static_cast(find_branch(e)); + auto& node_map = node_maps_[branch_index]; + auto output = bodies_[branch_index].get(); + Node* in = e->src(); + if (!IsSwitch(in)) { + if (node_map.at(in->id()) == nullptr) { + node_map[in->id()] = output->CopyNode(in); + } + output->AddEdge(node_map[in->id()], e->src_output(), + node_map.at(m->id()), 0); + } + } + } + return Status::OK(); +} + +Status Conditional::BuildIfNode(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(2) << "Build cond function for " << name(); + NodeDefBuilder builder(name(), "If", library); + const string branch_name[] = {"else_branch", "then_branch"}; + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + + NameAttrList body_name; + body_name.set_name( + absl::StrCat("_functionalize_if_", branch_name[branch_index], "_", id)); + + VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index] + << "): " + << dump_graph::DumpGraphToFile( + "functionalize_cond_body_" + branch_name[branch_index], + *bodies_[branch_index], nullptr); + + FunctionDef body_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index], + body_name.name(), &body_fdef)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + builder.Attr(branch_name[branch_index], body_name); + } + + VLOG(3) << "Build input type"; + std::vector inputs; + DataTypeVector in_arg_types; + for (auto& kv : cond_arg_nodes_) { + bool inserted = false; + for (const Node* arg : kv.switches) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + if (!inserted) { + DataType dtype = arg->input_type(0); + inputs.emplace_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), dtype)); + in_arg_types.push_back(dtype); + inserted = true; + } + } + } + } + builder.Attr("Tin", in_arg_types); + + DataTypeVector out_type; + for (const Node* merge : merges_) { + DataType dtype = merge->output_type(0); + out_type.push_back(dtype); + } + builder.Attr("Tout", out_type); + VLOG(3) << "Build output type: " << DataTypeVectorString(out_type); + + builder.Attr("Tcond", DT_BOOL); + builder.Device(predicate_.node->assigned_device_name()); + // Conditional should be the first input ... + builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(), + predicate_.index, + predicate_.node->output_type(0))); + // ... followed by the other inputs. + builder.Input(inputs); + + VLOG(3) << "Build If node"; + NodeDef if_def; + TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); + TF_ASSIGN_OR_RETURN(if_node_, + parent_->AddIfNode(if_def, *merges_.begin(), predicate_)); + + return Status::OK(); +} + +Status Conditional::AddInputEdges(Graph* graph) { + VLOG(2) << "AddInputEdges for " << if_node_->name(); + int index = 0; + // Add predicate input. + graph->AddEdge(const_cast(predicate_.node), predicate_.index, if_node_, + index++); + // Add function body inputs. + for (auto& arg : cond_arg_nodes_) { + if (arg.src_output == Graph::kControlSlot) { + graph->AddControlEdge(arg.src, if_node_); + } else { + graph->AddEdge(arg.src, arg.src_output, if_node_, index++); + } + } + for (Node* n : external_control_inputs_) { + graph->AddControlEdge(n, if_node_); + } + return Status::OK(); +} + +Status Conditional::AddOutputEdges(Graph* graph) { + VLOG(2) << "AddOutputEdges for " << if_node_->name(); + int i = 0; + for (Node* node : merges_) { + TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i)); + std::vector edges(node->out_edges().begin(), + node->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + if (edge->src_output() > 0) { + return errors::Unimplemented("Output of index (", edge->src_output(), + ") of merge node ", + FormatNodeForError(*node)); + } + + bool control_edge = edge->IsControlEdge(); + graph->RemoveEdge(edge); + if (control_edge) { + graph->AddControlEdge(if_node_, dst); + } else { + graph->AddEdge(if_node_, i, dst, dst_input); + } + } + ++i; + } + for (Node* n : external_control_outputs_) { + graph->AddControlEdge(if_node_, n); + } + + return Status::OK(); +} + +Status Conditional::BuildAndReplace(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(1) << "Build If and replace merge nodes " + << NodesToString(this->merges_); + if (replaced_) return Status::OK(); + + TF_RETURN_IF_ERROR(ExtractBodies(graph)); + TF_RETURN_IF_ERROR(BuildArgumentNodes()); + + if (VLOG_IS_ON(3)) { + LOG(INFO) << "Extracted bodies:"; + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + auto output = bodies_[branch_index].get(); + LOG(INFO) << Branch_Name(branch) << ": " + << DebugString(output->ToGraphDefDebug()); + } + } + + TF_RETURN_IF_ERROR(BuildIfNode(graph, library)); + TF_RETURN_IF_ERROR(AddInputEdges(graph)); + TF_RETURN_IF_ERROR(AddOutputEdges(graph)); + TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); + + // Check that the if_node doesn't feed into itself. + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNodeNotInCycle(if_node_, graph->num_node_ids()), + "Converting to If failed."); + + replaced_ = true; + return Status::OK(); +} + +string Conditional::name() const { + CHECK(!merges_.empty()); + return absl::StrCat((*merges_.begin())->name(), "_if"); +} + +Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, + int port) { + Node* id; + TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity") + .Input(if_node, port) + .Finalize(graph_, &id)); + state_map_.ResetCondId(id, state_map_.LookupCondId(if_node)); + state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node)); + return Status::OK(); +} + +StatusOr FunctionalizeCond::AddIfNode(const NodeDef& def, + const Node* replacee, + const OutputTensor& predicate) { + Status status; + Node* ret = graph_->AddNode(def, &status); + TF_RETURN_IF_ERROR(status); + VLOG(1) << "Adding If for " << replacee->name(); + StateMap::CondId id = state_map_.LookupCondId(replacee); + if (id) { + StateMap::CondState state = *id; + state.erase(predicate); + state_map_.ResetCondId(ret, state_map_.GetCondId(state)); + } else { + state_map_.ResetCondId(ret, nullptr); + } + + state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee)); + + return ret; +} + +Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { + VLOG(2) << "Propagating update state for " << replacee->name() << " " + << state_map_.CondStateToString(replacee); + // Redo topological sort as the order could have changed. + // TODO(jpienaar): The original topological order could also be updated + // dynamically if needed. + std::vector rev_topo_order; + GetPostOrder(*graph_, &rev_topo_order); + + // All the outputs of the new node could potentially be updated. + std::unordered_set changed; + for (auto n : replacee->out_nodes()) + if (n->IsOp()) changed.insert(n); + + // Iterate through the changed/possible changed nodes in topological order. + for (auto it = rev_topo_order.rbegin(); + it != rev_topo_order.rend() && !changed.empty(); ++it) { + if (changed.find(*it) != changed.end()) { + // Update the node state. + Node* n = *it; + StateMap::CondId old_state = state_map_.LookupCondId(n); + state_map_.ResetCondId(n, nullptr); + TF_RETURN_IF_ERROR(DetermineCondState(n)); + if (state_map_.LookupCondId(n) != old_state) { + for (auto out : n->out_nodes()) + if (out->IsOp()) changed.insert(out); + } + changed.erase(n); + } + } + return Status::OK(); +} + +// Returns the most restrictive branch of two branches or neither. This is the +// meet operator of the BranchType lattice. +BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) { + if (lhs == rhs) return lhs; + if (lhs == BranchType::kNeither) return rhs; + if (rhs == BranchType::kNeither) return lhs; + if (lhs == BranchType::kBoth) return rhs; + if (rhs == BranchType::kBoth) return lhs; + return BranchType::kNeither; +} + +BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const { + if (IsEmpty(id)) return BranchType::kNeither; + const CondState& nodes = *id; + auto it = nodes.find(predicate); + if (it == nodes.end()) return BranchType::kNeither; + return it->second; +} + +StatusOr FunctionalizeCond::JoinCondStatesNonMerge( + StateMap::CondId src, StateMap::CondId dst) { + VLOG(5) << "Joining src=" << DebugString(src) << " [" << src + << "] and dst=" << DebugString(dst) << " [" << dst << "]"; + + if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst; + + // Nothing to do if the CondState is the same. + if (src == dst) return src; + + StateMap::CondState both = *src; + for (const auto& kv : *dst) { + auto it = both.find(kv.first); + if (it == both.end()) { + both.insert(kv); + } else { + if (it->second != kv.second) { + return errors::InvalidArgument( + "Graph contains node with inputs predicated on incompatible " + "predicates: ", + DebugString(src), " and ", DebugString(dst)); + } + } + } + return state_map_.GetCondId(both); +} + +StatusOr FunctionalizeCond::JoinCondStatesMerge( + Node* merge, StateMap::CondId src, StateMap::CondId dst) { + // Determine the flow state when joining two states for a merge + // node. Combining the two states for a merge node is effectively performing a + // disjunction of the states along the different input edges. For a merge that + // can be transformed into a If the two inputs paths have to have a predicate + // on which they differ (e.g., along one edge predicate `p` has to hold while + // on another it should not). This function first determines this predicate + // and then the resultant state is the common path between the two inputs + // followed by s(p, both). + VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " + << DebugString(dst); + if (state_map_.IsEmpty(dst)) return src; + + if (state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst)) return dst; + + std::vector diff; + StateMap::CondState merged; + std::set_symmetric_difference(src->begin(), src->end(), dst->begin(), + dst->end(), std::back_inserter(diff), + CondStateLess()); + std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(), + std::inserter(merged, merged.begin()), CondStateLess()); + + // Update mapping from merge node to predicate. + if (diff.size() == 2) { + auto pred = diff[0].first; + bool different_branches = (diff[0].second != diff[1].second) && + (diff[0].second == BranchType::kThenBranch || + diff[0].second == BranchType::kElseBranch) && + (diff[1].second == BranchType::kThenBranch || + diff[1].second == BranchType::kElseBranch); + if (!(pred == diff[1].first) || !different_branches) + return errors::InvalidArgument( + "Unable to determine predicate for merge node"); + merge_to_predicate_[merge] = pred; + } else { + return errors::InvalidArgument( + "Merge of two inputs that differ on more than one predicate ", + DebugString(src), " and ", DebugString(dst)); + } + + return state_map_.GetCondId(merged); +} + +StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { + Node* src = e->src(); + StateMap::CondId id = state_map_.LookupCondId(e->src()); + + // Dead nodes only propagate dead state. + if (state_map_.IsDead(id)) return id; + + if (IsSwitch(src)) { + StateMap::CondState state; + if (id != nullptr) state = *id; + OutputTensor predicate; + TF_CHECK_OK(GetSwitchPredicate(*src, &predicate)); + if (!e->IsControlEdge()) { + state[predicate] = BranchType(e->src_output()); + } + return state_map_.GetCondId(state); + } + return id; +} + +Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { + // Only Merge nodes with two inputs are supported, but if this is a redundant + // merge, then the dead edge may already have been removed (if due to a + // switch) and so the input count would be incorrect. + if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK(); + + int data_inputs = 0; + for (auto e : dst->in_edges()) { + Node* src = e->src(); + VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " " + << state_map_.CondStateToString(src); + if (!src->IsOp()) continue; + if (!e->IsControlEdge()) ++data_inputs; + + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); + } + + // Incomplete Merge nodes are not supported. + if (data_inputs != 2) { + return errors::Unimplemented( + dst->name(), " only has ", data_inputs, + " inputs, while only merge nodes with two inputs supported."); + } + return Status::OK(); +} + +Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { + // Handle non-merge join. + for (auto e : dst->in_edges()) { + VLOG(4) << "Processing forward flow for: " << e->DebugString() << " " + << state_map_.CondStateToString(dst); + Node* src = e->src(); + if (!src->IsOp()) continue; + + // Joining the state between the current and propagated state. + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); + } + return Status::OK(); +} + +Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { + // Handle redundant merge nodes. A merge node is considered redundant if + // one input edge is dead while the other has a value. + if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK(); + + const Edge* non_dead_edge = nullptr; + for (auto e : node->in_edges()) { + if (e->IsControlEdge()) continue; + Node* src = e->src(); + + // Handle merge with dead state. + const auto& src_id = state_map_.LookupCondId(src); + if (!state_map_.IsDead(src_id)) { + non_dead_edge = e; + break; + } + } + + if (non_dead_edge == nullptr) { + return errors::InvalidArgument("Merge node ", FormatNodeForError(*node), + " has no non-dead inputs."); + } + state_map_.MarkDead(node); + VLOG(5) << "removing redundant merge: " << node->name(); + while (!node->out_edges().empty()) { + const Edge* oe = *node->out_edges().begin(); + Node* dst_node = oe->dst(); + int dst_port = oe->dst_input(); + graph_->RemoveEdge(oe); + graph_->AddEdge(non_dead_edge->src(), + dst_port == Graph::kControlSlot + ? Graph::kControlSlot + : non_dead_edge->src_output(), + dst_node, dst_port); + } + return Status::OK(); +} + +Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { + // Handle redundant switch nodes. A switch node is considered redundant if + // the predicate of the switch already holds on the current branch. E.g., if + // p is the predicate of the switch but p is already known to hold on this + // branch, then the switch can be removed and the dead state propagated + // along one. The checking of predicate is based on the exact predicate + // (rather than boolean equivalence) and aimed at redundant switches as + // currently generated by gradient code. + StateMap::CondId dst_id = state_map_.LookupCondId(node); + if (state_map_.IsDead(dst_id)) return Status::OK(); + + BranchType b; + OutputTensor pred; + TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred)); + + // Determine if we are already on a branch where the switch predicate is + // true/false. Consider both the data and predicate to determine if the + // node is redundant (skipping over identity node). + b = state_map_.FindBranchOf(dst_id, pred); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) { + OutputTensor val; + const Edge* e; + TF_RETURN_IF_ERROR(node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + while (IsIdentity(val.node)) { + TF_RETURN_IF_ERROR(val.node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + } + b = state_map_.FindBranchOf(dst_id, val); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) + return Status::OK(); + } + + VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " " + << DebugString(dst_id); + const Edge* value_edge; + TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge)); + Node* val_node = value_edge->src(); + int val_port = value_edge->src_output(); + while (!node->out_edges().empty()) { + auto e = *node->out_edges().begin(); + Node* dst_node = e->dst(); + int dst_input = e->dst_input(); + int switch_branch = e->src_output(); + graph_->RemoveEdge(e); + if (switch_branch == Graph::kControlSlot) { + if (IsMerge(dst_node)) { + auto id_or = JoinCondStatesMerge(dst_node, dst_id, + state_map_.LookupCondId(dst_node)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst_node)); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); + } else { + auto id_or = + JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node)); + TF_RETURN_IF_ERROR(id_or.status()); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); + } + } else if (BranchType(switch_branch) != b) { + state_map_.MarkDead(dst_node); + continue; + } + graph_->AddEdge( + val_node, + switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port, + dst_node, dst_input); + } + return Status::OK(); +} + +Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { + // The state that is propagated along the given edge. + for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) { + Node* dst = *it; + TF_RETURN_IF_ERROR(DetermineCondState(dst)); + TF_RETURN_IF_ERROR(DetermineAncestorState(dst)); + if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst)); + if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst)); + + VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst) + << " @ " << state_map_.AncestorStateToString(dst); + if (VLOG_IS_ON(10)) DumpGraphWithCondState("it"); + } + return Status::OK(); +} + +Status FunctionalizeCond::DetermineAncestorState(Node* dst) { + StateMap::AncestorId id = nullptr; + StateMap::AncestorState state; + + auto insert = [&](StateMap::AncestorId id, Node* src) { + auto other_id = state_map_.LookupAncestorId(src); + if (other_id != id && other_id != nullptr) { + state.insert(other_id->begin(), other_id->end()); + } + if (IsSwitch(src) || IsMerge(src)) { + state.insert(src); + } + return state_map_.GetAncestorId(state); + }; + + // Compute the union of all the switch/merge nodes that affects the input of + // dst. + for (auto e : dst->in_edges()) { + Node* src = e->src(); + id = insert(id, src); + } + state_map_.ResetAncestorId(dst, id); + return Status::OK(); +} + +void FunctionalizeCond::DeleteReachableAndDeadNodes( + const std::vector& switch_ids, const std::vector& merge_order) { + // Delete all nodes that have been extracted or are reachable from + // deleted/dead nodes. The input and outgoing edges should have already been + // removed. + std::deque delete_nodes; + std::vector deleted(graph_->num_node_ids(), false); + // Don't try to delete source or sink nodes. + deleted[graph_->kSourceId] = true; + deleted[graph_->kSinkId] = true; + + // All remaining Switch nodes are not reachable from a Merge node and + // removed. This is to account for dead Switch nodes. + for (int s_id : switch_ids) { + Node* s = graph_->FindNodeId(s_id); + if (s == nullptr) continue; + for (const Edge* e : s->out_edges()) { + // Control outputs of switch nodes (which are unconditionally executed if + // the switch is) are not removed as they need not be part of a + // conditional. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[s_id] = true; + graph_->RemoveNode(s); + } + + // All merge nodes should have been transformed at this point and we remove + // them from the graph here. + for (Node* m : merge_order) { + for (const Edge* e : m->out_edges()) { + // Similar to control outputs of switch nodes don't remove control + // outputs of merge nodes. + // TODO(jpienaar): Check cases where output edges still exist here vs + // being removed in AddOutputEdges. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[m->id()] = true; + graph_->RemoveNode(m); + } + + // Enqueue all the dead nodes. + for (Node* n : graph_->nodes()) { + if (state_map_.IsDead(state_map_.LookupCondId(n))) { + delete_nodes.push_back(n->id()); + } + } + + while (!delete_nodes.empty()) { + int d_id = delete_nodes.front(); + delete_nodes.pop_front(); + if (deleted[d_id]) continue; + Node* d = graph_->FindNodeId(d_id); + // Switch and Merge nodes could have been deleted already. + if (d == nullptr) continue; + for (const Edge* e : d->out_edges()) { + delete_nodes.push_back(e->dst()->id()); + } + deleted[d_id] = true; + graph_->RemoveNode(d); + } +} + +void FunctionalizeCond::SortMergeNodes(std::vector* merge_order) { + // Sort merge nodes by nesting depth. + using sort_pair = std::pair; + std::vector inner_to_outer_merge_order; + inner_to_outer_merge_order.reserve(merge_order->size()); + for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) { + Node* merge = *it; + StateMap::CondId id = state_map_.LookupCondId(merge); + int depth = id != nullptr ? id->size() : 0; + inner_to_outer_merge_order.emplace_back(depth, merge); + } + std::stable_sort( + inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(), + [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; }); + merge_order->clear(); + for (sort_pair t : inner_to_outer_merge_order) { + merge_order->push_back(t.second); + } +} + +Status FunctionalizeCond::FunctionalizeInternal() { + // The general approach for converting a tf.cond (as lowered via switch/merge + // nodes) to a functional if is as follows: + // 1. Determine the topological order and collect all the switch and merge + // nodes in the graph; + // 2. Compute the predicates and dominance structure for all the nodes in the + // graph - this includes which predicate must be true for a op to execute + // (predicate values are considered directly rather than attempting to + // determine deeper equivalence). We shall refer to this structure as the + // CondState; + // 3. Sort the merge nodes by nesting depth; + // 4. Extract merge nodes together that have the same CondState and + // AncestorState from the innermost to the outermost into IfOps; + // Note: In the above only nodes that feed into a merge node will be + // considered for functionalization. + + // Perform a DFS over the graph and + // * Determine the reverse topological order of the nodes (there should be no + // cycles at this point so the post-order numbering corresponds to the + // reverse topological sorting); + // * Record reverse topological for merge and switch nodes; + std::vector rev_topo_order; + std::vector switch_ids; + std::vector merge_order; + DFS(*graph_, nullptr, [&](Node* n) { + if (IsSwitch(n)) { + switch_ids.push_back(n->id()); + } + if (IsMerge(n)) { + merge_order.push_back(n); + } + if (n->IsOp()) { + rev_topo_order.push_back(n); + } + }); + + // No merges to functionalize. + if (merge_order.empty()) { + // No merges mean no switch values consumed (as only considering values + // fetchable as output of merge); + for (auto it = switch_ids.begin(); it != switch_ids.end(); ++it) { + graph_->RemoveNode(graph_->FindNodeId(*it)); + } + return Status::OK(); + } + + TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order))); + if (VLOG_IS_ON(4)) DumpGraphWithCondState("id"); + + // Sort the merge nodes from innermost outwards. + SortMergeNodes(&merge_order); + + // Cluster merge nodes by CondId and AncestorId in order of nesting. + using ClusterPair = std::pair; + std::deque> merge_clusters; + std::map merge_cluster_index; + for (Node* merge : merge_order) { + auto cond_id = state_map_.LookupCondId(merge); + if (state_map_.IsDead(cond_id)) continue; + + ClusterPair key = + std::make_pair(cond_id, state_map_.LookupAncestorId(merge)); + auto idx = merge_cluster_index.find(key); + if (idx == merge_cluster_index.end()) { + merge_cluster_index[key] = merge_clusters.size(); + merge_clusters.push_back({merge}); + } else { + merge_clusters[idx->second].emplace_back(merge); + } + } + + // Extract the conditionals from inner most to outer most. Extracting from + // innermost to outermost enables the extraction pass to stop once it + // encounters a Switch node instead of having to keep track of Switch/Merge + // nodes seen. + for (const auto& cluster : merge_clusters) { + // Construct a Conditional with the predicate of the merge. + Conditional cond(merge_to_predicate_.at(cluster.front()), this, + &state_map_); + for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge)); + TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); + + if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); + } + + DeleteReachableAndDeadNodes(switch_ids, merge_order); + + return Status::OK(); +} + +void FunctionalizeCond::DumpGraphWithCondState(const string& name) { + const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup"; + + for (Node* n : graph_->nodes()) { + n->ClearAttr(kCondGroupDebugAttr); + n->AddAttr(kCondGroupDebugAttr, + absl::StrCat(state_map_.CondStateToString(n), "_", + state_map_.AncestorStateToString(n))); + } + LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " + << dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_cond_", name), *graph_, + library_); +} + +Status FunctionalizeCond::Functionalize(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(1) << "FunctionalizeCond::Functionalize"; + FunctionalizeCond fc(graph, library); + return fc.FunctionalizeInternal(); +} + +} // namespace functionalize_cond + +Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) { + // FunctionalizeControlFlow is invoked for every function, so the loops's + // bodies and conditionals that were extracted into functions will be handled + // in successive invocations. + return functionalize_cond::FunctionalizeCond::Functionalize(graph, library); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h new file mode 100644 index 0000000000000000000000000000000000000000..189980894073b1da1a12d1c284536336eb920900 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -0,0 +1,243 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ + +#include +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Functionalize all the switch-merge nodes of a loop-free graph into If +// nodes. That is, attempt to transform every remaining switch and merge nodes +// in the graph into If nodes. +// Precondition: All while loops have been removed from graph. +Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); + +// Internal functions/classes exposed for testing purposes. +namespace functionalize_cond { + +// All nodes are assumed to be either in no branch, then branch, else branch, +// or both branches (such as merge nodes). +// The code below relies on Else and Then being 0 and 1 (corresponding to the +// switch outputs). Both and Neither are arbitrary. +enum class BranchType { + kElseBranch = 0, + kThenBranch = 1, + kBoth = 2, + kNeither = 3, +}; + +// StateMap is responsible for mapping from each graph Node to +// * a CondState, where each CondState is a map from predicate to branch (i,e., +// what predicates have to hold or not hold). +// * a AncestorState, where each AncestorState is a set of switch/merge nodes +// that are an ancestor of the node in the graph; +// For efficiency, this class interns the CondState (AncestorState), so that +// CondState (AncestorState) equality comparisons are simply pointer +// comparisons. +class StateMap { + public: + explicit StateMap(Graph* graph); + + // Compare two OutputTensors by (node id, index). + struct OutputTensorLess { + bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const; + }; + + // A node in the graph is executed when multiple conditions hold. Keep track + // of the predicates that must hold for a node to execute. + using CondState = std::map; + + // Every unique ID is mapped to a CondState. + using CondId = const CondState*; + + // Keep track of which switch/merge node's feed into a node's values. + using AncestorState = std::set; + + // Every unique ID is mapped to a AncestorState. + using AncestorId = const AncestorState*; + + // Returns the CondId for a given node. + CondId LookupCondId(const Node* node) const; + + // Returns the unique CondId for CondState. + CondId GetCondId(const CondState& state); + + // Resets the CondId for a given node. + void ResetCondId(const Node* node, CondId id); + + // Returns the AncestorId for a given node. + AncestorId LookupAncestorId(const Node* node) const; + + // Returns the unique AncestorId for CondState. + AncestorId GetAncestorId(const AncestorState& state); + + // Resets the AncestorId for a given node. + void ResetAncestorId(const Node* node, AncestorId id); + + // Marks `node` as dead. + void MarkDead(const Node* node); + + // Determine branch execution of CondState. + BranchType FindBranchOf(CondId id, OutputTensor predicate) const; + + // Returns textual representation of node's CondState. + string CondStateToString(const Node* node) const; + string CondStateToString(CondId id) const; + + // Returns textual representation of node's AncestorState. + string AncestorStateToString(const Node* node) const; + + // Returns whether the cond state is the dead state. + bool IsDead(CondId id) const; + + // Returns whether the cond state is the empty state. + bool IsEmpty(CondId id) const; + + private: + // Hash for CondState and AncestorState. + struct Hash { + size_t operator()(const CondState& map) const; + size_t operator()(const AncestorState& map) const; + }; + + // Set to keep track of unique CondStates. + // Pointers to the entries in the unordered set are used as identifiers: + // unordered_set guarantees that the pointers remain the same. + std::unordered_set condstate_set_; + + // Mapping from Node id to CondId. + std::vector node_to_condid_map_; + + // Track the CondId for newly inserted nodes. We use a vector to quickly map + // from Node id in the original graph to the CondId, but there will be nodes + // added to the original graph (such as If nodes) whose CondState needs to be + // tracked too. + std::unordered_map added_node_condid_mapping_; + + // AncestorId variants of the CondId members. + std::unordered_set ancestorstate_set_; + std::vector node_to_ancestorid_map_; + std::unordered_map added_node_ancestorid_mapping_; + + // Identifier of the dead flow state. The empty flow state is represented with + // a nullptr. + CondId dead_id_; +}; + +// FunctionalizeCond groups all the state used by functionalizing conditionals +// of the given graph together. +class FunctionalizeCond { + public: + // Functionalize all the switch-merge nodes of a loop-free graph into If + // nodes. That is, attempt to transform every remaining switch and merge nodes + // in the graph into If nodes. + // Precondition: All while loops have been removed from graph. + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); + + // Build identity node with the same name as the merge that will be replaced + // in case the output is fetched/colocated. + Status AddIdentityNode(const Node* replacee, Node* if_node, int port); + + // Add a If node to the graph defined by def that will, amongst other, replace + // replacee in the graph. + xla::StatusOr AddIfNode(const NodeDef& def, const Node* replacee, + const OutputTensor& predicate); + + // Propagates the state of a newly inserted node. + Status PropagateUpdatedState(const Node* replacee); + + // Dump graph with the CondState annotated. + void DumpGraphWithCondState(const string& name); + + private: + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); + + // Performs the actual cond functionalization. Iterate over groups of merge + // nodes (linked by common predicates & ancestor IDs), from innermost to + // outermost, and extract into If nodes. + Status FunctionalizeInternal(); + + // Returns the forward flow state propagated along edge `e`. + // This may modify state_map_. + StateMap::CondId StateAlongEdge(const Edge* e); + + // Determines the CondState and AncestorState of all the nodes in the given + // vector where the input is expected in reverse topological order. + // This populates the state_map_. + Status DetermineStates(std::vector rev_topo_order); + + // Determine the CondState for a given node using the incomming edges + // to the node. Note: it is expected that this node's CondState is only + // determined once its input's CondState is. + Status DetermineCondState(Node* dst) { + if (IsMerge(dst)) return DetermineCondStateMerge(dst); + return DetermineCondStateNonMerge(dst); + } + + // Helper functions for DetermineCondState. + Status DetermineCondStateNonMerge(Node* dst); + Status DetermineCondStateMerge(Node* dst); + + // Determines the dst node's CondState by joining the src and dst's CondState + // where either the dst node is a merge or not. + // These may modify state_map_. + xla::StatusOr JoinCondStatesMerge(Node* merge, + StateMap::CondId src, + StateMap::CondId dst); + xla::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst); + + // Determines which switch/merge nodes are ancestors of this node. + Status DetermineAncestorState(Node* dst); + + // Checks if a merge node is redundant and if so removes it from the graph. + Status RemoveRedundantMerge(Node* node); + + // Checks if a switch node is redundant and if so removes it from the graph. + Status RemoveRedundantSwitch(Node* node); + + // Sorts merge nodes (in reverse topological order) in order of increasing + // nesting depth. + void SortMergeNodes(std::vector* merge_order); + + // Deletes all nodes in/consumers reachable from switch/merge nodes that were + // extracted. + void DeleteReachableAndDeadNodes(const std::vector& switch_ids, + const std::vector& merge_order); + + // Member used to unique the CondState to a unique CondId (AncestorState to a + // unique AncestorId) and keep track of CondState/CondId + // (AncestorState/AncestorId) per Node. + StateMap state_map_; + + // Mapping from merge nodes to predicate. + std::unordered_map merge_to_predicate_; + + FunctionLibraryDefinition* library_; + Graph* graph_; + + friend class FunctionalizeCondTest; +}; + +} // namespace functionalize_cond + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0aabd63bbda784b3b7103a438ce025eea0cd93b --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests for the backward const analysis. + +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace functionalize_cond { + +class FunctionalizeCondTest : public ::testing::Test { + protected: + FunctionalizeCondTest() { + graph_.reset(new Graph(OpRegistry::Global())); + flib_def_.reset( + new FunctionLibraryDefinition(OpRegistry::Global(), fdef_lib_)); + fc_.reset(new functionalize_cond::FunctionalizeCond(graph_.get(), + flib_def_.get())); + } + + StateMap::CondId GetUniqueId(const StateMap::StateMap::CondState& state) { + return fc_->state_map_.GetCondId(state); + } + + string GetString(const StateMap::StateMap::CondId id) { + return fc_->state_map_.CondStateToString(id); + } + + xla::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesNonMerge(src, dst); + } + + xla::StatusOr JoinCondStatesMerge(Node* n, + StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesMerge(n, src, dst); + } + + FunctionDefLibrary fdef_lib_; + std::unique_ptr fc_; + std::unique_ptr flib_def_; + std::unique_ptr graph_; +}; + +namespace { + +TEST_F(FunctionalizeCondTest, JoinCondStates) { + Tensor pred_tensor(DT_BOOL, TensorShape()); + pred_tensor.flat().setZero(); + Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); + Tensor val_tensor(DT_INT32, TensorShape()); + val_tensor.flat().setZero(); + Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); + Node* m = test::graph::Merge(graph_.get(), val, val); + + StateMap::CondId then_branch; + { + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kThenBranch)); + then_branch = GetUniqueId(ss); + } + StateMap::CondId else_branch; + { + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kElseBranch)); + else_branch = GetUniqueId(ss); + } + + // An non-merge op with inputs from then and else branch. + Status status = JoinCondStatesNonMerge(then_branch, else_branch).status(); + EXPECT_TRUE(errors::IsInvalidArgument(status)); + + // Merge between then and else branch. + auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch); + TF_EXPECT_OK(joined_or.status()); + StateMap::CondId joined = joined_or.ValueOrDie(); + + // Merge between then branch and both branch. + auto t = JoinCondStatesNonMerge(then_branch, joined); + // Note: this is OK in terms of constraint predication, but + TF_EXPECT_OK(t.status()); +} + +} // namespace +} // namespace functionalize_cond +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 0904778f97c95628c81054cd4bc2ff32ff440a33..2d45507796a39d029665c709bb6ad27e7697d544 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -21,1433 +21,53 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/functionalize_while.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { -namespace { - -using xla::StatusOr; - -const char* const kArgOp = "_Arg"; -const char* const kRetValOp = "_Retval"; - -// Information about a loop argument. -struct Arg { - // Every loop argument has an Enter node. - Node* enter; - - // Is the loop argument a loop-invariant value? Taken from the `is_constant` - // attribute on the Enter node. - bool is_loop_invariant; - - // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant - // arguments must have all of the following nodes: - Node* merge = nullptr; - Node* switch_node = nullptr; - Node* next_iteration = nullptr; - Node* exit = nullptr; -}; - -// Information about a loop frame. -struct Frame { - string name; - - // Pointer to the parent frame. The root frame has a pointer to itself. - Frame* parent = nullptr; - int num_children = 0; - - // Arguments to this loop. - std::vector args; - - // The loop condition of the loop. There should be exactly one loop condition - // in every loop. - Node* loop_cond = nullptr; - - // Set of nodes that belong to the loop frame. - std::unordered_set nodes; -}; - -// Comparison function used for sorting nodes consistently. -// a) resource variables are last, and -// b) sort lexicographically by name (for deterministic output). -struct NodeCmp { - bool operator()(const Node* lhs, const Node* rhs) const { - bool lhs_is_resource = - lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; - bool rhs_is_resource = - rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; - return std::tie(lhs_is_resource, lhs->name()) < - std::tie(rhs_is_resource, rhs->name()); - } -}; - -// Returns a textual representation of the names of the nodes in the input. -template -string NodesToString(const T& nodes) { - return strings::StrCat("{", - str_util::Join(nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, - node->name()); - }), - "}"); -} - -// Copies a subgraph from `graph` to `output` by performing a reverse DFS -// starting at nodes in vector `stack`. -// `node_map` is a vector indexed by source node ID to dest nodes. -// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` -// before the traversal clients can cut the graph. If a frame is provided (frame -// != nullptr), then this functions will return an error if the -// traversal leaves 'frame'; the client must add enough nodes to `node_map` to -// cut the graph and prevent the traversal from escaping. -// -// `squash_src_outputs` contains a bool for each source node ID. If true, then -// the source output on that node will be replaced by zero when copied. This is -// used when replacing a Switch node with an _Arg node. The output we are -// taking from the Switch node was not necessarily the first output, but _Arg -// nodes only have one output. By adding the Switch node to `squash_src_outputs` -// we rewrite the src_output of the corresponding edge to be 0. -Status CopySubgraph(const Graph& graph, const Frame* frame, - std::vector stack, - const std::vector& squash_src_outputs, - std::vector* node_map, Graph* output) { - VLOG(3) << "Stack: " << NodesToString(stack); - std::vector visited(graph.num_node_ids(), false); - while (!stack.empty()) { - Node* n = stack.back(); - stack.pop_back(); - - VLOG(5) << "Copying node " << n->name(); - - if (visited[n->id()]) continue; - visited[n->id()] = true; - - for (const Edge* e : n->in_edges()) { - Node* src = e->src(); - if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { - // We traversed out of the loop frame, without encountering a cut node. - return errors::Internal("Graph traversal of loop frame ", frame->name, - " escaped frame at ", src->name(), - " without encountering an argument node."); - } - if ((*node_map)[src->id()] == nullptr) { - (*node_map)[src->id()] = output->CopyNode(src); - stack.push_back(src); - } - Node* src_copy = (*node_map)[e->src()->id()]; - int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() - ? 0 - : e->src_output(); - Node* dst_copy = (*node_map)[e->dst()->id()]; - output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); - } - } - return Status::OK(); -} - -StatusOr AddNode(const NodeDef& node_def, Graph* graph) { - Status status; - Node* inserted_node = graph->AddNode(node_def, &status); - if (!status.ok()) { - return status; - } - return inserted_node; -} - -// Check that the graph has no cycle containing the given node. -Status CheckNoCycleContains(const Node* node, const int num_nodes) { - std::vector ready; - ready.push_back(node); - std::vector visited(num_nodes); - while (!ready.empty()) { - const Node* current_node = ready.back(); - ready.pop_back(); - visited[current_node->id()] = true; - for (const Edge* out : current_node->out_edges()) { - if (out->dst() == node) { - return errors::Internal("Detected a cycle: ", FormatNodeForError(*node), - "(", node->def().op(), ") feeds into itself."); - } else if (!visited[out->dst()->id()]) { - ready.push_back(out->dst()); - } - } - } - return Status::OK(); -} - -StatusOr BuildArgNode(Graph* graph, DataType type, int index) { - NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); - builder.Attr("T", type); - builder.Attr("index", index); - TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); - return AddNode(arg_def, graph); -} - -StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { - NodeDef ret_def; - ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat(kRetValOp, index)); - AddNodeAttr("T", type, &ret_def); - AddNodeAttr("index", index, &ret_def); - return AddNode(ret_def, graph); -} - -// Builds a graph for the loop condition. -Status BuildLoopCondition(const Graph& graph, Frame* frame, - std::unique_ptr* cond_output) { - VLOG(2) << "Building loop condition for " << frame->name; - *cond_output = xla::MakeUnique(graph.op_registry()); - Graph* output = cond_output->get(); - - // Map from nodes in the original graph to the condition graph. - std::vector node_map(graph.num_node_ids(), nullptr); - std::vector squash_src_outputs(graph.num_node_ids(), false); - - // Build one _Arg node for each Enter node. - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - - TF_ASSIGN_OR_RETURN(Node * arg_node, - BuildArgNode(output, arg.enter->input_type(0), i)); - if (arg.is_loop_invariant) { - node_map[arg.enter->id()] = arg_node; - } else { - node_map[arg.merge->id()] = arg_node; - } - } - - // Build a Retval node for the loop condition. The LoopCond nodes are always - // boolean because of the type constraints on the LoopCond op. - TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], - BuildRetvalNode(output, DT_BOOL, 0)); - - // Performs a reverse DFS, copying nodes and edges to the output graph. - // The _Arg and _Retval nodes were added unconditionally above, so we are - // guaranteed to get the correct function signature. - return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, - &node_map, output); -} - -// Builds a graph for the loop body. -Status BuildLoopBody(const Graph& graph, Frame* frame, - DataTypeVector* arg_types, - std::unique_ptr* body_output) { - VLOG(2) << "Building loop body for " << frame->name; - *body_output = xla::MakeUnique(graph.op_registry()); - Graph* output = body_output->get(); - - // Map from nodes in the original graph to the condition graph. - std::vector node_map(graph.num_node_ids(), nullptr); - std::vector squash_src_outputs(graph.num_node_ids(), false); - - // Build one _Arg node for each Enter node. - std::vector next_iterations; - next_iterations.reserve(frame->args.size()); - arg_types->reserve(frame->args.size()); - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - - DataType dtype = arg.enter->input_type(0); - arg_types->push_back(dtype); - - TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); - - if (dtype == DT_RESOURCE) { - // The convention of the XLA bridge is that resource variable arguments - // are only inputs to the loop body and have no corresponding output. - // TODO(b/37741920): change the convention so that DT_RESOURCE variables - // are both inputs and outputs, and then remove this case. - TF_RET_CHECK(arg.is_loop_invariant); - node_map[arg.enter->id()] = arg_node; - } else { - TF_ASSIGN_OR_RETURN(Node * retval_node, - BuildRetvalNode(output, dtype, i)); - - if (arg.is_loop_invariant) { - // Argument is loop-invariant. Forward it from the Arg to the Retval. - node_map[arg.enter->id()] = arg_node; - output->AddEdge(arg_node, 0, retval_node, 0); - } else { - // Argument is loop-varying. - node_map[arg.switch_node->id()] = arg_node; - // The Switch node has two outputs, but _Arg only has one. This tells - // the CopySubgraph function to rewrite the output number of edges from - // the _Arg node to be 0 rather than copying the output number from the - // Switch node. - squash_src_outputs[arg.switch_node->id()] = true; - node_map[arg.next_iteration->id()] = retval_node; - next_iterations.push_back(arg.next_iteration); - } - } - } - - // Performs a reverse DFS, copying nodes and edges to the output graph. - // The _Arg and _Retval nodes were added unconditionally above, so we are - // guaranteed to get the correct function signature. - TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), - squash_src_outputs, &node_map, output)); - - return Status::OK(); -} - -// Copy the FunctionDef of given function from lookup_library to library, if -// it can be found in lookup_library but is missing from library. -Status AddMissingFunctionByName(const string& function_name, - const FunctionLibraryDefinition* lookup_library, +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, FunctionLibraryDefinition* library) { - if (!library->Find(function_name) && lookup_library->Find(function_name)) { - return library->AddFunctionDef(*lookup_library->Find(function_name)); - } - return Status::OK(); -} - -// Iterate over all functions that the given fdef refers to. Copy the missing -// FunctionDefs from lookup_library to library. -Status AddMissingFunctionDef(const FunctionDef& fdef, - const FunctionLibraryDefinition* lookup_library, - FunctionLibraryDefinition* library) { - TF_RET_CHECK(lookup_library); - for (const NodeDef& node : fdef.node_def()) { - if (library->Find(node.op())) { - continue; - } - // The function referred by 'SymbolicGradient' node is specified in its - // attribute 'f'. - if (node.op() == FunctionLibraryDefinition::kGradientOp) { - const AttrValue* attr = - AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); - if (!attr) { - return errors::InvalidArgument("SymbolicGradient is missing attr: f"); - } - const string& func_name = attr->func().name(); - TF_RETURN_IF_ERROR( - AddMissingFunctionByName(func_name, lookup_library, library)); - // Copy the user-defined gradient function if it exists. - const string grad_name = lookup_library->FindGradient(func_name); - if (!grad_name.empty() && library->FindGradient(func_name).empty()) { - TF_RETURN_IF_ERROR( - AddMissingFunctionByName(grad_name, lookup_library, library)); - GradientDef grad_def; - grad_def.set_function_name(func_name); - grad_def.set_gradient_func(grad_name); - TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); - } - } else if (lookup_library->Find(node.op())) { - TF_RETURN_IF_ERROR( - library->AddFunctionDef(*lookup_library->Find(node.op()))); - } - } - return Status::OK(); -} - -Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, - Graph* graph, Frame* frame, - FunctionLibraryDefinition* library) { - VLOG(2) << "Frame " << frame->name << " before: " - << dump_graph::DumpGraphToFile("functionalize_before", *graph, + VLOG(2) << "FunctionalizeControlFlow (initial): " + << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); - // Split loop-varying Enter nodes with multiple successors. If the same - // Tensor is fed as input to multiple loop arguments, we may end up with a - // shared Enter node. We clone Enter nodes with multiple successors to - // maintain the invariant of a unique Enter node per argument of the final - // loop. - std::vector args; - for (const Arg& arg : frame->args) { - if (arg.is_loop_invariant) { - args.push_back(arg); - } else { - std::vector edges(arg.enter->out_edges().begin(), - arg.enter->out_edges().end()); - for (int i = 0; i < edges.size(); ++i) { - if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { - continue; - } - TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); - Arg new_arg; - new_arg.is_loop_invariant = false; - if (i == 0) { - new_arg.enter = arg.enter; - } else { - new_arg.enter = graph->CopyNode(arg.enter); - frame->nodes.insert(new_arg.enter); - for (Edge const* e : arg.enter->in_edges()) { - graph->AddEdge(e->src(), e->src_output(), new_arg.enter, - e->IsControlEdge() ? Graph::kControlSlot : 0); - } - Node* dst = edges[i]->dst(); - int dst_input = edges[i]->dst_input(); - graph->RemoveEdge(edges[i]); - graph->AddEdge(new_arg.enter, 0, dst, dst_input); - } - args.push_back(new_arg); - } - } - } - frame->args = std::move(args); - - std::sort( - frame->args.begin(), frame->args.end(), - [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); }); + // Functionalize and remove while loops from graph. + TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library)); - if (frame->loop_cond == nullptr) { - return errors::InvalidArgument("Loop ", frame->name, - " has no LoopCond node"); - } - - // Find the set of Switch nodes that are successors of the LoopCond. - std::unordered_set switches; - for (const Edge* edge : frame->loop_cond->out_edges()) { - if (!edge->IsControlEdge() && IsSwitch(edge->dst()) && - edge->dst_input() == 1) { - switches.insert(edge->dst()); - } - } - - // For each non-constant argument, looks for the following pattern of nodes: - // Enter ----> Merge --------> Switch --> Exit - // ^ ^ - // | | - // NextIteration LoopCond - // ^ ^ - // | | - // ... ... - for (Arg& arg : frame->args) { - if (!arg.is_loop_invariant) { - // Follow the edge from the Enter to Merge. - const Edge* enter_merge = nullptr; - for (const Edge* e : arg.enter->out_edges()) { - // Ignore control-edges to the sink node. These are allowed by the - // graph invariants, although probably they should have been stripped - // off earlier. - if (e->IsControlEdge() && e->dst()->IsSink()) { - continue; - } - if (enter_merge != nullptr) { - return errors::Internal("Enter node for loop-varying argument ", - FormatNodeForError(*arg.enter), - " has multiple successors: ", - FormatNodeForError(*enter_merge->dst()), - " and ", FormatNodeForError(*e->dst())); - } - enter_merge = e; - } - if (enter_merge == nullptr) { - return errors::Internal("Enter node for loop-varying argument ", - FormatNodeForError(*arg.enter), - " has zero successors"); - } - arg.merge = enter_merge->dst(); - if (!IsMerge(arg.merge)) { - return errors::InvalidArgument( - "Successor of Enter node for loop-varying argument ", - FormatNodeForError(*arg.merge), - " is not a Merge node; got: ", arg.merge->type_string()); - } - - // Find the NextIteration from the merge. There should be two inputs to - // the Merge and the NextIteration should be the other input. - if (arg.merge->input_types().size() != 2) { - return errors::InvalidArgument( - "Unexpected number of inputs to Merge node for loop-varying " - "argument ", - FormatNodeForError(*arg.merge), "; expected 2, got ", - arg.merge->input_types().size()); - } - TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), - &arg.next_iteration)); - if (!IsNextIteration(arg.next_iteration)) { - return errors::InvalidArgument( - "Expected NextIteration node as input to Merge node; got node ", - FormatNodeForError(*arg.next_iteration), " with kind ", - arg.next_iteration->type_string()); - } - - // Find the Switch successor of the Merge. There should be exactly one - // Switch node that is a successor of both the Merge and the LoopCond. - for (const Edge* edge : arg.merge->out_edges()) { - if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && - switches.find(edge->dst()) != switches.end()) { - if (arg.switch_node != nullptr) { - return errors::InvalidArgument("Duplicate Switch successors to ", - FormatNodeForError(*arg.merge)); - } - arg.switch_node = edge->dst(); - } - } - if (arg.switch_node == nullptr) { - return errors::InvalidArgument("Missing Switch successor to ", - FormatNodeForError(*arg.merge)); - } - - // Update the device on the Identity outputs of the switch to match their - // target. These Identity outputs do not - - // Loop over the switch node's output to: - // - Find the Exit successor. - // - Set the sharding on all Identity outputs of the switch. These - // identity nodes are values used by the loop body or condition. - // The Identity node may have the wrong device so copy the device from - // one of its outputs instead. - std::deque possible_exit; - for (const Edge* edge : arg.switch_node->out_edges()) { - if (edge->src_output() == 0) { - possible_exit.push_back(edge); - } - if (IsIdentity(edge->dst())) { - TF_RETURN_IF_ERROR( - SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); - } - } - // TODO(b/67425339): Allow general graph between switch and exit. - while (!possible_exit.empty()) { - const Edge* edge = possible_exit.front(); - possible_exit.pop_front(); - if (IsExit(edge->dst())) { - if (arg.exit != nullptr) { - return errors::InvalidArgument( - "Duplicate Exit successors to ", - FormatNodeForError(*arg.switch_node)); - } - arg.exit = edge->dst(); - } else { - if (!IsIdentity(edge->dst())) { - return errors::Unimplemented("General graph between switch (", - FormatNodeForError(*arg.switch_node), - ") and exit node of frame ", - frame->name, " not supported yet."); - } - for (const Edge* out : edge->dst()->out_edges()) { - possible_exit.push_back(out); - } - } - } - } - } - - // Builds the condition and body functions. - std::unique_ptr cond_graph; - TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); - DataTypeVector arg_types; - std::unique_ptr body_graph; - TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); - - VLOG(2) << "Frame " << frame->name << " condition: " - << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) - << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); - - static std::atomic sequence_num(0LL); - int64 id = ++sequence_num; - NameAttrList cond_name; - cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); - NameAttrList body_name; - body_name.set_name(strings::StrCat("_functionalize_body_", id)); - FunctionDef cond_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); - FunctionDef body_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef)); - - TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); - TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); - if (lookup_library) { - // Copy missing FunctionDefs from lookup_library to library to make library - // self-contained. - TF_RETURN_IF_ERROR( - AddMissingFunctionDef(cond_fdef, lookup_library, library)); - TF_RETURN_IF_ERROR( - AddMissingFunctionDef(body_fdef, lookup_library, library)); - } - - // Builds a While operator. - NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); - builder.Attr("T", arg_types); - builder.Attr("cond", cond_name); - builder.Attr("body", body_name); - std::vector inputs; - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - builder.ControlInput(in_edge->src()->name()); - } else { - inputs.push_back(NodeDefBuilder::NodeOut( - in_edge->src()->name(), in_edge->src_output(), arg_types[i])); - } - } - builder.Input(inputs); - TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); - TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph)); - - // Copies edges to the Enter nodes and from the Exit nodes onto the While. - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - graph->AddControlEdge(in_edge->src(), while_node); - } else { - graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i); - } - - if (!arg.is_loop_invariant) { - // Add output edges if the output of the loop is consumed. - if (arg.exit != nullptr) { - std::vector edges(arg.exit->out_edges().begin(), - arg.exit->out_edges().end()); - for (const Edge* edge : edges) { - Node* dst = edge->dst(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - if (dst_input == Graph::kControlSlot) { - graph->AddControlEdge(while_node, dst); - } else { - graph->AddEdge(while_node, i, dst, dst_input); - } - } - } - } - } - - // Remove the old nodes from the graph, and add the while node to the parent - // frame. - for (Node* node : frame->nodes) { - graph->RemoveNode(node); - } - frame->nodes.clear(); - frame->parent->nodes.insert(while_node); + // FunctionalizeControlFlow is invoked for every function, so the loops's + // bodies and conditionals that were extracted into functions will be handled + // in successive invocations. + TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library)); - VLOG(2) << "Frame " << frame->name << " after: " - << dump_graph::DumpGraphToFile("functionalize_after", *graph, + VLOG(2) << "FunctionalizeControlFlow (final): " + << dump_graph::DumpGraphToFile("functionalize_final", *graph, library); return Status::OK(); } -class FunctionalizeCond { - public: - // All nodes are assumed to be either in no branch, then branch, else branch, - // or both branches (such as merge nodes). - enum Branch { - kElseBranch = 0, - kThenBranch = 1, - kBoth = 2, - kNeither = 3, - kNumBranchTypes = 4 - }; - - // Returns a textual representation of the Branch b. - static string Branch_Name(FunctionalizeCond::Branch b); - - // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf - // nodes. That is, attempt to transform every remaining switch and merge nodes - // in the graph into XlaIf nodes. - // Precondition: All while loops have been removed from graph. - static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); - - private: - // CondArgNode represents a input to the conditional and its corresponding - // switch nodes. - struct CondArgNode { - explicit CondArgNode(Node* src, int src_output) - : src(src), src_output(src_output) {} - string ToString() const { - return strings::StrCat("src=", src->name(), ":", src_output, - " switches=", NodesToString(switches)); - } - - Node* src; - int src_output; - std::vector switches; - }; - using CondArgNodes = std::vector; - - struct ForwardFlowNode { - explicit ForwardFlowNode(Branch branch = Branch::kNeither) - : branch(branch), count(0) {} - string ToString() const { - return strings::StrCat("branch=", Branch_Name(branch), " count=", count); - } - Branch branch; - int count; - }; - - // Group of switch nodes that will be part of the same XlaIf. - struct SwitchCluster { - explicit SwitchCluster(const Edge* predicate_edge) - : predicate_edge(predicate_edge) {} - string ToString() const { - return strings::StrCat(name, " predicate=", predicate_edge->src()->name(), - " switches=", NodesToString(switches)); - } - - string name; - const Edge* predicate_edge; - std::vector switches; - }; - - FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, - bool dump_graphs) - : library_(library), graph_(graph), dump_graphs_(dump_graphs) {} - - // Perform the actual cond functionalization. Iterate over groups of switch - // nodes (linked by common predicate), from innermost to outermost, and - // extract into XlaIf nodes. - Status FunctionalizeInternal(); - - // Determines the branch_map (mapping from node to branch of cond) and - // frontier (the nodes where the cond ends). - StatusOr, - std::unordered_set>> - DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster); - - // Returns XlaIf node created from subgraph of merge and switch nodes. This - // encapsulates the process of extracting the bodies needed for the then and - // else branch, creates a XlaIf node, removing the nodes of the branches from - // the graph and replacing the merge node with a XlaIf. - StatusOr ConvertToXlaIf(const CondArgNodes& cond_arg_nodes, - const SwitchCluster& switch_cluster, - const std::vector& switches); - - // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. - StatusOr BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes, - const SwitchCluster& switch_cluster, - const std::vector& merge_nodes); - - // Extracts a function body corresponding to the given input edge of the merge - // node. - Status ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switches, - const std::vector& merge_nodes, int input_edge, - Graph* body); - - // Adds all the input edges to `if_node` corresponding to the arguments. - Status AddInputEdges(const CondArgNodes& cond_arg_nodes, - const Edge* predicate_edge, Node* if_node); - - // Adds all output edges from the `if_node`. - Status AddOutputEdges(const std::vector& outputs, Node* if_node); - - // Returns the switch clusters of graph_ in postorder. Dead switch nodes are - // skipped and removed from the graph. - StatusOr> DeterminePredicateSwitchOrder(); - - // Update the state for destination based on the state of source and the node - // being updated. - Status Join(const ForwardFlowNode& src_state, const Node* dst, - ForwardFlowNode* dst_state); - - // Ensure that all nodes in the branch_map are dominated by the switch - // nodes. Returns nodes that are not dominated by the switches but are a - // control dependency of a node in the cond, and remove such control - // dependencies. - StatusOr> EnsureDominanceAndReturnNonDominatedControlNodes( - const std::unordered_map& branch_map, - const std::vector& switches); - - // Validates that the frontier of nodes for the conditional - // section are as expected. - Status ValidateFrontier( - const std::unordered_map& branch_map, - const std::unordered_set& frontier); - - FunctionLibraryDefinition* library_; - Graph* graph_; - bool dump_graphs_; -}; - -bool IsDeadSwitch(const Node* node) { - for (const Edge* e : node->out_edges()) { - const Node* dst = e->dst(); - if (!dst->IsIdentity()) { - return false; - } - for (const Edge* ee : dst->out_edges()) { - if (!ee->IsControlEdge() || !ee->dst()->IsSink()) { - return false; - } - } - } - return true; -} - -string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) { - const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = { - "else", "then", "both", "neither", "count"}; - return branch_name[b]; -} - -Status FunctionalizeCond::ValidateFrontier( - const std::unordered_map& - branch_map, - const std::unordered_set& frontier) { - std::unordered_set pending[kNumBranchTypes]; - for (Node* n : frontier) { - pending[branch_map.at(n).branch].insert(n); - } - TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]); - for (const Node* n : pending[kBoth]) { - TF_RET_CHECK(IsMerge(n)) << n->DebugString(); - // Merge nodes may be in then or else branch too - } - int index = (pending[kThenBranch].size() <= pending[kElseBranch].size()) - ? kThenBranch - : kElseBranch; - int other = 1 - index; - for (const Node* n : pending[index]) { - if (pending[other].find(n) != pending[other].end()) { - return errors::Internal( - "Node (", n->DebugString().c_str(), - ") in both Else and Then branch should be in Both."); - } - } - // An empty frontier indicates a dead switch. Above we attempt to remove dead - // switch nodes, but not all are removed so don't treat it as an error yet. - // TODO(jpienaar): Find out why dead switch nodes remain. - // if (pending[kBoth].empty() && pending[kThenBranch].empty() && - // pending[kElseBranch].empty()) { - // return errors::Internal("Unexpected empty frontier for switch nodes"); - // } - return Status::OK(); -} - -Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, - const Node* dst, ForwardFlowNode* dst_state) { - TF_RET_CHECK(dst_state->branch != Branch::kBoth && - dst_state->branch != Branch::kNumBranchTypes) - << "Unexpected/Invalid branch type: Merging " - << Branch_Name(src_state.branch) << " with " - << Branch_Name(dst_state->branch); - if (dst_state->branch == Branch::kNeither) { - dst_state->branch = src_state.branch; - } else if (src_state.branch != dst_state->branch && - src_state.branch != Branch::kNeither) { - if (IsMerge(dst)) { - dst_state->branch = Branch::kBoth; - } else { - return errors::Internal("Illegal merge:\n", src_state.ToString(), - " with ", dst_state->ToString(), " for\n", - dst->DebugString()); - } - } - ++dst_state->count; - return Status::OK(); -} - -StatusOr> -FunctionalizeCond::DeterminePredicateSwitchOrder() { - struct Cluster { - bool operator==(const Cluster& other) const { - return representative == other.representative; - } - int representative = -1; - }; - - // Perform a DFS over the graph and - // * Determine the reverse topological order of the nodes (there should be no - // cycles at this point so the post-order numbering corresponds to the - // reverse topological sorting); - // * Identify dead switches; - // * Initialize the cluster's representative; - std::vector> clusters(graph_->num_node_ids()); - std::vector dead_switches; - std::vector switch_order; - std::vector rev_topo_sorted_nodes; - DFS(*graph_, nullptr, [&](Node* n) { - clusters[n->id()].Get().representative = n->id(); - if (IsSwitch(n)) { - if (IsDeadSwitch(n)) { - dead_switches.push_back(n); - } else { - rev_topo_sorted_nodes.push_back(n); - switch_order.push_back(n); - } - } else if (n->IsOp()) { - // Exclude src and sink nodes from further consideration. - rev_topo_sorted_nodes.push_back(n); - } - }); - - std::vector switch_clusters; - // Return early if there are no switches in the graph. - if (switch_order.empty()) { - return switch_clusters; - } - - // Remove all dead switch nodes. - for (Node* n : dead_switches) { - VLOG(2) << "Removing dead switch: " << n->DebugString(); - graph_->RemoveNode(n); - } - - // Identify switch nodes that are part of the same control flow context by - // considering the operands of operations: an operation is part of the same - // control context as its operands unless the operation is a switch. Control - // dependencies are considered part of the same control flow context if the - // switch depth is the same (see comment below). - - // entry_cluster records the input cluster to a switch node. This is used when - // merging with a merge node where the dst's cluster is merged with the entry - // cluster of the merge node's cluster (which corresponds to a switch cluster - // and so has an entry cluster). - std::unordered_map*> entry_cluster; - - // Returns the output cluster of a node. Where the output cluster is cluster - // where the output of the node is used. For non-merge nodes this is simply - // the cluster they are part of, while for merge nodes it is the entry cluster - // of the cluster they are part of (this will correspond to the entry node of - // a switch node that dominates the merge). - auto find_output_cluster = [&](Node* n) { - UnionFind* cluster = &clusters[n->id()]; - if (!IsMerge(n)) return cluster; - auto it = entry_cluster.find(clusters[n->id()].Get().representative); - // If the cluster is not found in the entry_cluster map then an - // instruction not dominated by a switch node has been merged into the - // cluster of the merge. This indicates a failure of the clustering. - CHECK(it != entry_cluster.end()) - << "Unable to find entry for n=" << n->id() << " (" - << cluster->Get().representative << ")"; - return it->second; - }; - - // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier. - std::vector switch_depth(graph_->num_node_ids()); - for (auto it = rev_topo_sorted_nodes.rbegin(); - it != rev_topo_sorted_nodes.rend(); ++it) { - Node* n = *it; - - // Compute switch depth. - int new_switch_depth = 0; - for (const Edge* e : n->in_edges()) { - Node* src = e->src(); - new_switch_depth = std::max( - new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0)); - } - switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0); - - // Only merge the input operands of a switch. The switch's clustering itself - // is determined by the interaction of the switch's outputs. - if (IsSwitch(n)) { - Node* input; - TF_CHECK_OK(n->input_node(0, &input)); - entry_cluster[n->id()] = find_output_cluster(input); - UnionFind* cluster = entry_cluster[n->id()]; - int cluster_depth = switch_depth[cluster->Get().representative]; - // Merge the inputs of the switch node with one another. This results in - // predicates and control input residing in the same cluster. - for (const Edge* e : n->in_edges()) { - // Only consider the data inputs to the Switch node. - if (e->IsControlEdge()) continue; - - Node* src = e->src(); - UnionFind* src_cluster = find_output_cluster(src); - int src_cluster_depth = switch_depth[src_cluster->Get().representative]; - if (cluster_depth != src_cluster_depth) { - return errors::InvalidArgument( - "Unable to functionalize control flow in graph: Switch ('", - n->name(), "') has operands ('", input->name(), "' and '", - src->name(), "') that have different switch depths (", - cluster_depth, " != ", src_cluster_depth, ")"); - } - cluster->Merge(src_cluster); - } - continue; - } - - for (const Edge* e : n->in_edges()) { - Node* src = e->src(); - if (!src->IsOp()) continue; - UnionFind* cluster = find_output_cluster(src); - // Merge a node with its data operands and with its control operands if - // the src and dst are in the same ControlContext. The ControlContext is - // not explicitly available here, and instead the switch depth is used as - // a proxy here. Due to the invariant that control edges can only be from - // a containing scope to an inner scope or from the inner scope to its - // containing scope (for exit nodes), the switch depth will only match if - // the src and dst are in the same ControlContext. Control edges between - // ControlContexts are handled during the extraction. - int src_id = cluster->Get().representative; - int src_depth = switch_depth[src_id]; - if (!e->IsControlEdge() || new_switch_depth == src_depth) { - if (src_depth != new_switch_depth) { - // TODO(b/77601805) remove this when outside_compilation supports - // control flow. - if (str_util::StrContains(src->name(), "outside_compilation") || - str_util::StrContains(n->name(), "outside_compilation")) { - return errors::InvalidArgument( - "outside_compilation is not yet supported within TensorFlow " - "control flow constructs b/77601805"); - } - return errors::InvalidArgument( - "Unable to functionalize control flow in graph: Operand ('", - src->name(), "') and operator ('", n->name(), - "') have different switch depths (", src_depth, - " != ", new_switch_depth, ")"); - } - cluster->Merge(&clusters[n->id()]); - } - } - } - - if (dump_graphs_) { - // Mark the switch cluster each node is part of. - for (Node* n : graph_->nodes()) { - n->ClearAttr("_XlaFunctionalizeSwitchGroup"); - n->AddAttr("_XlaFunctionalizeSwitchGroup", - clusters[n->id()].Get().representative); - } - LOG(INFO) << "FunctionalizeControlFlow (with_clusters): " - << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_, - library_); - } - - // Verify all the nodes of a cluster are at the same depth. - std::unordered_map> cluster_to_depth_node; - for (Node* n : graph_->nodes()) { - int depth = switch_depth[n->id()]; - int cluster_rep = clusters[n->id()].Get().representative; - auto it = cluster_to_depth_node.find(cluster_rep); - if (it == cluster_to_depth_node.end()) { - cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n); - } else { - if (it->second.first != depth) { - return errors::Internal( - "Illegal clustering created, mismatch in depths:", "\n\t", - n->DebugString(), "(", clusters[n->id()].Get().representative, - ") at depth=", depth, " vs\n\t", it->second.second->DebugString(), - "(", clusters[n->id()].Get().representative, ") at depth ", - it->second.first); - } - } - } - - struct Hash { - size_t operator()(const std::pair& item) const { - return Hash64Combine(hash()(item.first), - std::hash()(item.second.representative)); - } - }; - - // Merge Switch nodes with common predicate. - std::unordered_map, int, Hash> predicate_index; - // The nodes in switch_order are in reverse topological order, but the - // clustered switches need not be (i.e., when considered as a cluster one - // element of a cluster may be later in the topological order than another - // node whose cluster is later in the topological order of clustered - // switches). - for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { - const Edge* pred_edge; - TF_CHECK_OK((*it)->input_edge(1, &pred_edge)); - // The predicate can be preceded by a identity node. Look through identity - // nodes to predicate. - while (pred_edge->src()->IsIdentity()) { - TF_CHECK_OK(pred_edge->src()->input_edge(0, &pred_edge)); - } - auto repr = std::make_pair(pred_edge->src(), clusters[(*it)->id()].Get()); - if (predicate_index.find(repr) == predicate_index.end()) { - predicate_index[repr] = switch_clusters.size(); - switch_clusters.emplace_back(pred_edge); - // Generate a name by concatenating with the cluster representative as - // there could be multiple switch clusters with the same predicate. - switch_clusters[predicate_index[repr]].name = strings::StrCat( - pred_edge->src()->name(), "_", repr.second.representative, "_If"); - } - switch_clusters[predicate_index[repr]].switches.push_back(*it); - } - - return switch_clusters; -} - -StatusOr> -FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes( - const std::unordered_map& branch_map, - const std::vector& switches) { - std::vector old_control_nodes; - for (const auto& kv : branch_map) { - if (kv.second.count != kv.first->in_edges().size()) { - std::vector delete_edges; - for (const Edge* in : kv.first->in_edges()) { - auto it = branch_map.find(in->src()); - if (it == branch_map.end()) { - if (in->IsControlEdge()) { - old_control_nodes.push_back(in->src()); - delete_edges.push_back(in); - } else { - if (IsSwitch(in->src())) { - if (std::find(switches.begin(), switches.end(), in->src()) == - switches.end()) { - return errors::Internal( - "Unexpected switch node found during flow forward: ", - in->src()->DebugString()); - } - continue; - } - return errors::InvalidArgument( - "Value ", kv.first->name(), "'s input, ", in->src()->name(), - ", is not dominated by switch nodes ", NodesToString(switches)); - } - } - } - // Remove control edges from nodes that are not dominated by the switch - // nodes. New control dependencies will be added between these nodes and - // the XlaIf node inserted. - for (const Edge* e : delete_edges) { - graph_->RemoveEdge(e); - } - } - } - return old_control_nodes; -} - -StatusOr< - std::pair, - std::unordered_set>> -FunctionalizeCond::DetermineBranchMapAndFrontier( - const SwitchCluster& switch_cluster) { - std::unordered_map branch_map; - std::unordered_set frontier; - std::vector stack = switch_cluster.switches; - std::vector visited(graph_->num_node_ids(), false); - while (!stack.empty()) { - Node* n = stack.back(); - stack.pop_back(); - - if (visited[n->id()]) { - continue; - } - visited[n->id()] = true; - - // Propagate branch state along each edge of a switch node. - bool sink_only = true; - for (const Edge* e : n->out_edges()) { - Node* out = e->dst(); - if (!out->IsOp()) { - continue; - } - sink_only = false; - // Propagate branch information. - ForwardFlowNode& ffn = branch_map[out]; - if (IsSwitch(n)) { - int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); - TF_RETURN_WITH_CONTEXT_IF_ERROR( - Join(ForwardFlowNode(Branch(index)), out, &ffn), " when joining ", - e->DebugString()); - } else { - TF_RETURN_WITH_CONTEXT_IF_ERROR(Join(branch_map[n], out, &ffn), - " when joining ", e->DebugString()); - } - if (IsMerge(out)) { - if (out->in_edges().size() == ffn.count) { - frontier.insert(out); - } - } else if (!visited[out->id()]) { - stack.push_back(out); - } - } - if (sink_only) { - if (!IsIdentity(n)) { - VLOG(1) << "Feeding into sink: " << n->DebugString(); - } - } - } - - if (dump_graphs_) { - for (const auto& kv : branch_map) { - // Append attribute to the graph if running with logging to make the - // changes clearer in the visualization. - kv.first->AddAttr("_XlaFunctionalizeBranch", - Branch_Name(kv.second.branch)); - } - } - return std::make_pair(std::move(branch_map), std::move(frontier)); -} - -Status FunctionalizeCond::FunctionalizeInternal() { - TF_ASSIGN_OR_RETURN(std::vector predicate_switch_order, - DeterminePredicateSwitchOrder()); - - // Iterate from innermost set of clustered switches to outermost, replacing - // matching switch->merge subgraphs with single XlaIf nodes. - for (auto it = predicate_switch_order.rbegin(); - it != predicate_switch_order.rend(); ++it) { - auto& ps = *it; - VLOG(3) << "Flow down from: " << ps.ToString(); - - std::unordered_map branch_map; - std::unordered_set frontier; - TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier), - DetermineBranchMapAndFrontier(ps)); - - if (dump_graphs_) - LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_bc", *graph_, - library_); - TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); - - struct Hash { - size_t operator()(const std::pair& item) const { - return Hash64Combine(hash()(item.first), - std::hash()(item.second)); - } - }; - - // Sort the merge and switch nodes using NodeCmp. The switch-nodes are - // further grouped (post sorting) by input to the switch node as in the - // functionalized form each input will be passed in only once. This grouping - // should retain the sorted order. - CondArgNodes cond_arg_nodes; - std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp()); - std::unordered_map, int, Hash> input_index; - for (Node* switch_node : ps.switches) { - const Edge* e; - TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e)); - std::pair key = std::make_pair(e->src(), e->src_output()); - if (input_index.find(key) == input_index.end()) { - input_index[key] = cond_arg_nodes.size(); - cond_arg_nodes.emplace_back(key.first, key.second); - } - cond_arg_nodes.at(input_index.at(key)).switches.push_back(switch_node); - } - std::vector merge_nodes(frontier.begin(), frontier.end()); - std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); - - TF_ASSIGN_OR_RETURN(std::vector old_control_nodes, - EnsureDominanceAndReturnNonDominatedControlNodes( - branch_map, ps.switches)); - - TF_ASSIGN_OR_RETURN(Node * if_node, - ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes)); - for (Node* old : old_control_nodes) { - graph_->AddControlEdge(old, if_node); - } - - for (auto& del_kv : branch_map) { - graph_->RemoveNode(del_kv.first); - } - for (auto& kv : cond_arg_nodes) { - for (Node* node : kv.switches) { - graph_->RemoveNode(node); - } - } - if (dump_graphs_) - LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_ac", *graph_, - library_); - } - return Status::OK(); -} - -StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, - const std::vector& merge_nodes) { - VLOG(2) << "Build if op for " << switch_cluster.name; - - NodeDef if_def; - // Create a new If node using the name of the merge node. - NodeDefBuilder builder(switch_cluster.name, "XlaIf"); - string branch[] = {"else_branch", "then_branch"}; - for (int i = 0; i < 2; ++i) { - static std::atomic sequence_num(0LL); - int64 id = ++sequence_num; - - NameAttrList body_name; - body_name.set_name( - strings::StrCat("_functionalize_if_", branch[i], "_", id)); - auto body = xla::MakeUnique(graph_->op_registry()); - TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches, - merge_nodes, i, body.get())); - VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); - FunctionDef body_fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); - TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); - builder.Attr(branch[i], body_name); - } - - // Build input type. - std::vector inputs; - DataTypeVector in_arg_types; - for (auto& kv : cond_arg_nodes) { - bool inserted = false; - for (const Node* arg : kv.switches) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - builder.ControlInput(in_edge->src()->name()); - } else { - if (!inserted) { - DataType dtype = arg->input_type(0); - inputs.emplace_back(NodeDefBuilder::NodeOut( - in_edge->src()->name(), in_edge->src_output(), dtype)); - in_arg_types.push_back(dtype); - inserted = true; - } - } - } - } - builder.Attr("Tin", in_arg_types); - - // Build output type. - DataTypeVector out_type; - for (const Node* merge : merge_nodes) { - DataType dtype = merge->output_type(0); - out_type.push_back(dtype); - } - builder.Attr("Tout", out_type); - - builder.Attr("Tcond", DT_BOOL); - builder.Device(switch_cluster.predicate_edge->src()->assigned_device_name()); - // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut( - switch_cluster.predicate_edge->src()->name(), - switch_cluster.predicate_edge->src_output(), - switch_cluster.predicate_edge->src()->output_type(0))); - // ... followed by the other inputs. - builder.Input(inputs); - - TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); - TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_)); - return if_node; -} - -Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switches, - const std::vector& merge_nodes, - int input_edge, Graph* body) { - VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " - << input_edge; - std::vector squash_src_outputs(graph_->num_node_ids(), false); - std::vector node_map(graph_->num_node_ids(), nullptr); - int arg_count = 0; - for (auto& kv : cond_arg_nodes) { - Node* arg_node = nullptr; - for (const auto* arg : kv.switches) { - DataType dtype = arg->input_type(0); - if (arg_node == nullptr) { - TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++)); - } - node_map.at(arg->id()) = arg_node; - squash_src_outputs.at(arg->id()) = true; - } - } - - std::vector stack; - stack.reserve(merge_nodes.size()); - for (int j = 0; j < merge_nodes.size(); ++j) { - Node* node = merge_nodes[j]; - TF_ASSIGN_OR_RETURN(node_map.at(node->id()), - BuildRetvalNode(body, node->output_type(0), - /*index=*/j)); - const Edge* in_edge; - TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge)); - Node* in = in_edge->src(); - if (node_map.at(in->id()) == nullptr) { - node_map.at(in->id()) = body->CopyNode(in); - } - - if (std::find(switches.begin(), switches.end(), in) == switches.end()) { - body->AddEdge(node_map.at(in->id()), in_edge->src_output(), - node_map.at(node->id()), 0); - } else { - body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0); - // Don't include input nodes that are already just returned in stack. - continue; - } - stack.push_back(in); - } - - return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map, - body); -} - -Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, - const Edge* predicate_edge, - Node* if_node) { - VLOG(3) << "AddInputEdges for " << if_node->name(); - int index = 0; - graph_->AddEdge(predicate_edge->src(), predicate_edge->src_output(), if_node, - index++); - for (auto& arg : cond_arg_nodes) { - if (arg.src_output == Graph::kControlSlot) { - graph_->AddControlEdge(arg.src, if_node); - } else { - graph_->AddEdge(arg.src, arg.src_output, if_node, index++); - } - } - return Status::OK(); -} - -Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, - Node* if_node) { - VLOG(3) << "AddOutputEdges for " << if_node->name(); - for (int i = 0; i < outputs.size(); ++i) { - Node* node = outputs[i]; - std::vector edges(node->out_edges().begin(), - node->out_edges().end()); - for (const Edge* edge : edges) { - Node* dst = edge->dst(); - int dst_input = edge->dst_input(); - - if (edge->src_output() > 0) { - return errors::Unimplemented("Output of index (", edge->src_output(), - ") of merge node ", node->name()); - } - - int src_output = - dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; - graph_->RemoveEdge(edge); - graph_->AddEdge(if_node, src_output, dst, dst_input); - } - } - return Status::OK(); -} - -StatusOr FunctionalizeCond::ConvertToXlaIf( - const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, - const std::vector& merge_nodes) { - VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> " - << NodesToString(merge_nodes); - - // Extract bodies and builds a If operator. - TF_ASSIGN_OR_RETURN( - Node * if_node, - BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); - TF_RETURN_IF_ERROR( - AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); - TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); - // Check that the if_node doesn't feed into itself. - TF_RETURN_WITH_CONTEXT_IF_ERROR( - CheckNoCycleContains(if_node, graph_->num_node_ids()), - "ConvertToXlaIf failed."); - - return if_node; -} - -Status FunctionalizeCond::Functionalize(Graph* graph, - FunctionLibraryDefinition* library) { - VLOG(1) << "FunctionalizeCond::Functionalize"; - FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2)); - return fc.FunctionalizeInternal(); -} - -} // namespace - // Transformation that converts TensorFlow's graph control flow constructs into // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, @@ -1455,104 +75,174 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } -Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, - Graph* graph, - FunctionLibraryDefinition* library) { - VLOG(2) << "FunctionalizeControlFlow (initial): " - << dump_graph::DumpGraphToFile("functionalize_initial", *graph, - library); - - // Note: BuildControlFlowInfo() requires that the graph's source node is - // connected to all source nodes in the graph. Many graphs violate this - // invariant. - std::vector cf_info; - std::vector unreachable_nodes; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes), - "FunctionalizeControlFlow failed"); - if (!unreachable_nodes.empty()) { - return errors::InvalidArgument( - "The following nodes are unreachable from the source in the graph: ", - errors::FormatNodeNamesForError(unreachable_nodes)); - } - - // Builds Frames, indexed by name. - std::unordered_map frames; - for (Node* node : graph->op_nodes()) { - const ControlFlowInfo& cf = cf_info[node->id()]; - - VLOG(2) << "node: " << node->name() << " (" << node->id() - << ") frame_name: " << cf.frame_name - << " frame: " << (cf.frame ? cf.frame->name() : "---") - << " parent_frame: " - << (cf.parent_frame ? cf.parent_frame->name() : "---"); - TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); - - Frame& frame = frames[cf.frame_name]; - Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; - if (frame.parent == nullptr) { - frame.parent = parent; - frame.name = cf.frame_name; - ++parent->num_children; - } - - if (IsEnter(node)) { - Arg arg; - arg.enter = node; - TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", - &arg.is_loop_invariant)); - frame.args.push_back(arg); - } else if (IsLoopCond(node)) { - frame.loop_cond = node; - } - frame.nodes.insert(node); - } - - // Adds frames with no children (i.e., the innermost frames) to a worklist. - std::deque worklist; - for (auto& frame : frames) { - if (frame.second.num_children == 0) { - worklist.push_back(&frame.second); - } - } - - // Eliminate loops from innermost to outermost. - while (!worklist.empty()) { - Frame* frame = worklist.front(); - worklist.pop_front(); - if (frame->parent == frame) { - // Skip the root frame. - continue; +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + std::map* canonicalized_name_to_new_name) { + // Convert the function to Graph. + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); } - + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + + // If any node has associated functions, functionalize them first. + // Gather nodes with associated functions first, because rewriting those nodes + // might involve node deletion/addition. Avoid modifying nodes while iterating + // it. + std::vector>> + nodes_to_associated_functions; + for (auto* n : body->graph->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, flr); + if (!associated_functions.empty()) { + nodes_to_associated_functions.push_back({n, associated_functions}); + } + } + for (auto iter : nodes_to_associated_functions) { + Node* n = iter.first; + auto associated_functions = iter.second; + for (auto& associated_function : associated_functions) { + string name = associated_function.func_name(); + string canonicalized_name = + Canonicalize(name, AttrSlice(&associated_function.attrs())); + auto iter = canonicalized_name_to_new_name->find(canonicalized_name); + string new_name; + if (iter != canonicalized_name_to_new_name->end()) { + // If we already functionalized this function, skip functionalization + // but still rewrite the node. + new_name = iter->second; + } else { + new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + name, new_name, associated_function.attrs(), fld, flr, + canonicalized_name_to_new_name)); + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } + // Notice that if "n" is a function call, RewriteAssociatedFunction() will + // delete it and create a new node instead, making "n" an invalid pointer. + // That's fine because in that case, associated_functions will only have + // one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + body->graph, n, fld, associated_function, new_name)); + } + } + + // Call graph optimizer. The most important optimization we need is constant + // folding, which will replace ops like Shape/BroadcastGradientArgs with + // constant shape input. Without this optimization, those ops might become + // dynamic input for then/else body function and XLA will complain that input + // is not compile time constant. We enable function inlining as well, because + // otherwise we won't be able to infer shape for any node depending on + // function call nodes. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_opt_", func_name), + *body->graph, fld); + } + // Optimizer accepts std::unique_ptr* as input and might change + // underlying pointer, thus we create a new Graph and copy from body->graph. + std::unique_ptr optimized_graph(new Graph(fld)); + CopyGraph(*body->graph, optimized_graph.get()); + OptimizerOptions opts; + opts.set_opt_level(OptimizerOptions::L0); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); + GraphOptimizer optimizer(opts); + auto cf_consider_fn = [](const Node* n) { + // Skip SymbolicGradient op when doing constant folding. + // Enabling SymbolicGradient op in constant folding requires + // flr->device() to be non-null, and here we have not constructed + // proper Device object yet (it will be constructed in XlaCompiler). + return n->type_string() != FunctionLibraryDefinition::kGradientOp; + }; + optimizer.Optimize(flr, flr->env(), + /*device=*/nullptr, &optimized_graph, + /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr, + cf_consider_fn); + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *optimized_graph, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), + *optimized_graph, fld); + } + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, + &functionalized_fdef)); + + // Add rewritten FunctionDef into library. + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; TF_RETURN_IF_ERROR( - FunctionalizeLoop(lookup_library, graph, frame, library)); - - // If the parent has no remaining children, add it to the worklist. - --frame->parent->num_children; - if (frame->parent->num_children == 0) { - worklist.push_back(frame->parent); + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } + + return ret_status; +} + +Status FunctionalizeControlFlowPass::Run( + const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, + options.flib_def); + } + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime( + /*device_mgr=*/nullptr, options.session_options->env, + TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + // Find XLA compile ops and its corresponding FunctionDef. + static std::map* kNodeTypeToFunctionAttrMapping = + new std::map{ + {"TPUCompile", "function"}, + {"XlaLaunch", "function"}, + }; + std::map canonicalized_name_to_new_name; + for (Node* n : graph->nodes()) { + auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); + if (it == kNodeTypeToFunctionAttrMapping->end()) { + continue; } - } - // There should be no cycle at this point, since while loops have been removed - // from graph. - // Check that the newly added XlaWhile nodes don't feed into themselves. - for (const Node* node : graph->op_nodes()) { - if (node->def().op() == "XlaWhile") { - TF_RETURN_WITH_CONTEXT_IF_ERROR( - CheckNoCycleContains(node, graph->num_node_ids()), - "FunctionalizeLoop failed."); + const string func_attr = it->second; + if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != + kNodeTypeToFunctionAttrMapping->end()) { + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name)); + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); } } - // FunctionalizeControlFlow is invoked for every function, so the loops's - // bodies and conditionals that were extracted into functions will be handled - // in successive invocations. - TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); - - VLOG(2) << "FunctionalizeControlFlow (final): " - << dump_graph::DumpGraphToFile("functionalize_final", *graph, - library); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, + options.flib_def); + } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index d941041d15532446d1413f16fe64602bfb1a7daa..ba99205640ccdc83a3a4d50e3ec474907894a835 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -16,20 +16,31 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While -// operators, suitable for XLA compilation. If lookup_library is provided, use -// it to make the library for control flow self-contained. +// operators and tf.cond() conditionals into function If operators, suitable for +// XLA compilation. If lookup_library is provided, use it to make the library +// for control flow self-contained. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (If/While). +class FunctionalizeControlFlowPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc new file mode 100644 index 0000000000000000000000000000000000000000..a10a9d0499457bbc0383ea3a8c678f153e21894b --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +namespace tensorflow { + +// This pass is required for some AOT backends and all JIT backends, so this +// file exists as a separate lib and will be linked to both AOT and JIT. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27, + FunctionalizeControlFlowPass); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index ccf249b35d66861888ad5e5e904b5f63b8ac50a1..c3841f996f801e855da75b23f01d41674ec51c4d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" @@ -37,12 +38,12 @@ limitations under the License. namespace tensorflow { namespace { -// Returns the names of the "then" and "else" functions for the XlaIf node in a +// Returns the names of the "then" and "else" functions for the If node in a // graph. Status FindIfThenAndElse(const GraphDef& graph, string* op_name, NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { - if (node.op() == "XlaIf") { + if (node.op() == "If") { *op_name = node.name(); const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); @@ -52,7 +53,7 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, return Status::OK(); } } - return errors::NotFound("No XlaIf node found in graph"); + return errors::NotFound("No If node found in graph"); } // Graph: @@ -112,9 +113,10 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, then_fn, - else_fn, {DT_INT32}); + auto if_op = ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, + then_fn, else_fn); + auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -172,7 +174,7 @@ TEST(FunctionalizeControlFlow, Conditional) { Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, NameAttrList* body) { for (const NodeDef& node : graph.node()) { - if (node.op() == "XlaWhile") { + if (node.op() == "While") { const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); *cond = *result; @@ -181,7 +183,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, return Status::OK(); } } - return errors::NotFound("No XlaWhile node found in graph"); + return errors::NotFound("No While node found in graph"); } // Graph: @@ -250,8 +252,8 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -387,8 +389,8 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); GraphDef expected; TF_ASSERT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -478,8 +480,8 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -620,8 +622,8 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{x, y}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); GraphDef expected; @@ -800,11 +802,11 @@ TEST(FunctionalizeControlFlow, Complex) { auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); - auto one = - ops::Const(scope.WithOpName("outer/inner/One") - .WithControlDependencies( - gtl::ArraySlice{assign.operation}), - 1); + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); auto add_j = ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); @@ -818,7 +820,7 @@ TEST(FunctionalizeControlFlow, Complex) { scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(gtl::ArraySlice{ + .WithControlDependencies(absl::Span{ exit_j.output.op(), exit_k.output.op()}), identity_i, one_outer); auto next_iteration_i = @@ -859,9 +861,9 @@ TEST(FunctionalizeControlFlow, Complex) { auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"), - std::initializer_list{zero, y, x, var}, - outer_cond_fn, outer_body_fn); + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -916,15 +918,15 @@ TEST(FunctionalizeControlFlow, Complex) { auto one_j = ops::Const( scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); auto while_op = - ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); + ops::While(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); auto one_outer = ops::Const( scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(gtl::ArraySlice{ + .WithControlDependencies(absl::Span{ while_op[0].op(), while_op[1].op()}), identity_i, one_outer); @@ -986,11 +988,11 @@ TEST(FunctionalizeControlFlow, Complex) { auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); - auto one = - ops::Const(scope.WithOpName("outer/inner/One") - .WithControlDependencies( - gtl::ArraySlice{assign.operation}), - 1); + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); auto add_j = ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); @@ -1013,63 +1015,5 @@ TEST(FunctionalizeControlFlow, Complex) { } } -TEST(FunctionalizeControlFlow, Cycle) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - // ----------------------------------------------------- - // | | - // | v - // less -> switch_1 --> add -> merge_1 -> identity -> switch_2 - // | ^ | - // | | v - // --------> one -------------------------> add_2 ---> merge_2 - { - Scope scope = Scope::NewRootScope().ExitOnError(); - - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); - auto two = - ops::Const(scope.WithOpName("cond/two") - .WithControlDependencies(switch_1.output_true), - 2); - auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"), - switch_1.output_true, two); - auto one = - ops::Const(scope.WithOpName("cond/one") - .WithControlDependencies(switch_1.output_false), - 1); - auto add = ops::Add(scope.WithOpName("cond/false/add"), - switch_1.output_false, one); - - auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"), - std::initializer_list{add, mul}); - auto identity = - ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output); - auto switch_2 = - ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less); - auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"), - switch_2.output_false, one); - auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"), - switch_2.output_true, two); - auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"), - std::initializer_list{add_2, mul_2}); - TF_ASSERT_OK(scope.ToGraph(graph.get())); - } - // No cycle before functionalize control flow. - TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph)); - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - // switch_1 and switch_2 have the same switch depth. They are replaced by a - // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle: - // less -> XlaIf <--> identity. - Status status = FunctionalizeControlFlow(graph.get(), &library); - EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detected a cycle")) - << status.error_message(); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), "{{node cond/Less_5_If}}")) - << status.error_message(); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..54cebc61778ba051b9c903f8e2c3696cec69843a --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" + +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { + +bool NodeCmpByNameResourcesLast::operator()(const Node* lhs, + const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); +} + +xla::StatusOr AddNodeDefToGraph(const NodeDef& node_def, Graph* graph) { + Status status; + Node* inserted_node = graph->AddNode(node_def, &status); + if (!status.ok()) { + return status; + } + return inserted_node; +} + +xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { + const char* const kRetValOp = "_Retval"; + NodeDef ret_def; + ret_def.set_op(kRetValOp); + ret_def.set_name(absl::StrCat(kRetValOp, index)); + AddNodeAttr("T", type, &ret_def); + AddNodeAttr("index", index, &ret_def); + return AddNodeDefToGraph(ret_def, graph); +} + +// Check that the graph has no cycle containing the given node. +Status CheckNodeNotInCycle(const Node* node, const int num_nodes) { + std::vector ready; + ready.push_back(node); + std::vector visited(num_nodes); + while (!ready.empty()) { + const Node* current_node = ready.back(); + ready.pop_back(); + visited[current_node->id()] = true; + for (const Edge* out : current_node->out_edges()) { + if (out->dst() == node) { + return errors::Internal("Detected a cycle: ", FormatNodeForError(*node), + " (", node->def().op(), ") feeds into itself."); + } else if (!visited[out->dst()->id()]) { + ready.push_back(out->dst()); + } + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h new file mode 100644 index 0000000000000000000000000000000000000000..582b49d5116acc651fb6242b5c2b9aeeac269532 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ + +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/graph/graph.h" + +// Utility functions shared between functionalize cond and while. + +namespace tensorflow { + +// Check that the graph has no cycle containing the given node. +Status CheckNodeNotInCycle(const Node* node, const int num_nodes); + +// Comparison function used for sorting nodes consistently. +// a) resource variables are last, and +// b) sort lexicographically by name (for deterministic output). +struct NodeCmpByNameResourcesLast { + bool operator()(const Node* lhs, const Node* rhs) const; +}; + +// Returns the Node* created from the NodeDef in the Graph. +xla::StatusOr AddNodeDefToGraph(const NodeDef& node_def, Graph* graph); + +// Build a retval node of given type and index. +xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); + +// Returns a textual representation of the names of the nodes in the input. +template +string NodesToString(const T& nodes) { + return absl::StrCat("{", + absl::StrJoin(nodes, ",", + [](string* output, const Node* node) { + absl::StrAppend(output, node->name()); + }), + "}"); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc new file mode 100644 index 0000000000000000000000000000000000000000..7c3ad448ef546dd1ab2640a57d7d1d73ca3768ad --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -0,0 +1,677 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_while.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace { + +using xla::StatusOr; + +// Information about a loop argument. +struct Arg { + // Every loop argument has an Enter node. + Node* enter; + + // Is the loop argument a loop-invariant value? Taken from the `is_constant` + // attribute on the Enter node. + bool is_loop_invariant; + + // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant + // arguments must have all of the following nodes: + Node* merge = nullptr; + Node* switch_node = nullptr; + Node* next_iteration = nullptr; + Node* exit = nullptr; +}; + +// Information about a loop frame. +struct Frame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. + Frame* parent = nullptr; + int num_children = 0; + + // Arguments to this loop. + std::vector args; + + // The loop condition of the loop. There should be exactly one loop condition + // in every loop. + Node* loop_cond = nullptr; + + // Set of nodes that belong to the loop frame. + std::unordered_set nodes; +}; + +// Copies a subgraph from `graph` to `output` by performing a reverse DFS +// starting at nodes in vector `stack`. +// `node_map` is a vector indexed by source node ID to dest nodes. +// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` +// before the traversal clients can cut the graph. If a frame is provided (frame +// != nullptr), then this functions will return an error if the +// traversal leaves 'frame'; the client must add enough nodes to `node_map` to +// cut the graph and prevent the traversal from escaping. +// +// `squash_src_outputs` contains a bool for each source node ID. If true, then +// the source output on that node will be replaced by zero when copied. This is +// used when replacing a Switch node with an _Arg node. The output we are +// taking from the Switch node was not necessarily the first output, but _Arg +// nodes only have one output. By adding the Switch node to `squash_src_outputs` +// we rewrite the src_output of the corresponding edge to be 0. +Status CopySubgraph(const Graph& graph, const Frame* frame, + std::vector stack, + const std::vector& squash_src_outputs, + std::vector* node_map, Graph* output) { + VLOG(3) << "Stack: " << NodesToString(stack); + std::vector visited(graph.num_node_ids(), false); + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + VLOG(5) << "Copying node " << n->name(); + + if (visited[n->id()]) continue; + visited[n->id()] = true; + + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { + // We traversed out of the loop frame, without encountering a cut node. + return errors::Internal("Graph traversal of loop frame ", frame->name, + " escaped frame at ", src->name(), + " without encountering an argument node."); + } + if ((*node_map)[src->id()] == nullptr) { + (*node_map)[src->id()] = output->CopyNode(src); + stack.push_back(src); + } + Node* src_copy = (*node_map)[e->src()->id()]; + int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() + ? 0 + : e->src_output(); + Node* dst_copy = (*node_map)[e->dst()->id()]; + output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); + } + } + return Status::OK(); +} + +StatusOr BuildArgNode(Graph* graph, DataType type, int index) { + const char* const kArgOp = "_Arg"; + NodeDef arg_def; + NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp); + builder.Attr("T", type); + builder.Attr("index", index); + TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); + return AddNodeDefToGraph(arg_def, graph); +} + +// Builds a graph for the loop condition. +Status BuildLoopCondition(const Graph& graph, Frame* frame, + std::unique_ptr* cond_output) { + VLOG(2) << "Building loop condition for " << frame->name; + *cond_output = absl::make_unique(graph.op_registry()); + Graph* output = cond_output->get(); + + // Map from nodes in the original graph to the condition graph. + std::vector node_map(graph.num_node_ids(), nullptr); + std::vector squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node. + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + TF_ASSIGN_OR_RETURN(Node * arg_node, + BuildArgNode(output, arg.enter->input_type(0), i)); + if (arg.is_loop_invariant) { + node_map[arg.enter->id()] = arg_node; + } else { + node_map[arg.merge->id()] = arg_node; + } + } + + // Build a Retval node for the loop condition. The LoopCond nodes are always + // boolean because of the type constraints on the LoopCond op. + TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], + BuildRetvalNode(output, DT_BOOL, 0)); + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, + &node_map, output); +} + +// Builds a graph for the loop body. +Status BuildLoopBody(const Graph& graph, Frame* frame, + DataTypeVector* arg_types, + std::unique_ptr* body_output) { + VLOG(2) << "Building loop body for " << frame->name; + *body_output = absl::make_unique(graph.op_registry()); + Graph* output = body_output->get(); + + // Map from nodes in the original graph to the condition graph. + std::vector node_map(graph.num_node_ids(), nullptr); + std::vector squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node. + std::vector next_iterations; + next_iterations.reserve(frame->args.size()); + arg_types->reserve(frame->args.size()); + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + DataType dtype = arg.enter->input_type(0); + arg_types->push_back(dtype); + + TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); + + if (dtype == DT_RESOURCE) { + // The convention of the XLA bridge is that resource variable arguments + // are only inputs to the loop body and have no corresponding output. + // TODO(b/37741920): change the convention so that DT_RESOURCE variables + // are both inputs and outputs, and then remove this case. + TF_RET_CHECK(arg.is_loop_invariant); + node_map[arg.enter->id()] = arg_node; + } else { + TF_ASSIGN_OR_RETURN(Node * retval_node, + BuildRetvalNode(output, dtype, i)); + + if (arg.is_loop_invariant) { + // Argument is loop-invariant. Forward it from the Arg to the Retval. + node_map[arg.enter->id()] = arg_node; + output->AddEdge(arg_node, 0, retval_node, 0); + } else { + // Argument is loop-varying. + node_map[arg.switch_node->id()] = arg_node; + // The Switch node has two outputs, but _Arg only has one. This tells + // the CopySubgraph function to rewrite the output number of edges from + // the _Arg node to be 0 rather than copying the output number from the + // Switch node. + squash_src_outputs[arg.switch_node->id()] = true; + node_map[arg.next_iteration->id()] = retval_node; + next_iterations.push_back(arg.next_iteration); + } + } + } + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), + squash_src_outputs, &node_map, output)); + + return Status::OK(); +} + +// Copy the FunctionDef of given function from lookup_library to library, if +// it can be found in lookup_library but is missing from library. +Status AddMissingFunctionByName(const string& function_name, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + if (!library->Find(function_name) && lookup_library->Find(function_name)) { + return library->AddFunctionDef(*lookup_library->Find(function_name)); + } + return Status::OK(); +} + +// Iterate over all functions that the given fdef refers to. Copy the missing +// FunctionDefs from lookup_library to library. +Status AddMissingFunctionDef(const FunctionDef& fdef, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + TF_RET_CHECK(lookup_library); + for (const NodeDef& node : fdef.node_def()) { + if (library->Find(node.op())) { + continue; + } + // The function referred by 'SymbolicGradient' node is specified in its + // attribute 'f'. + if (node.op() == FunctionLibraryDefinition::kGradientOp) { + const AttrValue* attr = + AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); + if (!attr) { + return errors::InvalidArgument("SymbolicGradient is missing attr: f"); + } + const string& func_name = attr->func().name(); + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(func_name, lookup_library, library)); + // Copy the user-defined gradient function if it exists. + const string grad_name = lookup_library->FindGradient(func_name); + if (!grad_name.empty() && library->FindGradient(func_name).empty()) { + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(grad_name, lookup_library, library)); + GradientDef grad_def; + grad_def.set_function_name(func_name); + grad_def.set_gradient_func(grad_name); + TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); + } + } else if (lookup_library->Find(node.op())) { + TF_RETURN_IF_ERROR( + library->AddFunctionDef(*lookup_library->Find(node.op()))); + } + } + return Status::OK(); +} + +Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, Frame* frame, + FunctionLibraryDefinition* library) { + VLOG(2) << "Frame " << frame->name << " before: " + << dump_graph::DumpGraphToFile("functionalize_before", *graph, + library); + + // Split loop-varying Enter nodes with multiple successors. If the same + // Tensor is fed as input to multiple loop arguments, we may end up with a + // shared Enter node. We clone Enter nodes with multiple successors to + // maintain the invariant of a unique Enter node per argument of the final + // loop. + std::vector args; + for (const Arg& arg : frame->args) { + if (arg.is_loop_invariant) { + args.push_back(arg); + } else { + std::vector edges(arg.enter->out_edges().begin(), + arg.enter->out_edges().end()); + for (int i = 0; i < edges.size(); ++i) { + if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { + continue; + } + TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); + Arg new_arg; + new_arg.is_loop_invariant = false; + if (i == 0) { + new_arg.enter = arg.enter; + } else { + new_arg.enter = graph->CopyNode(arg.enter); + frame->nodes.insert(new_arg.enter); + for (Edge const* e : arg.enter->in_edges()) { + graph->AddEdge(e->src(), e->src_output(), new_arg.enter, + e->IsControlEdge() ? Graph::kControlSlot : 0); + } + Node* dst = edges[i]->dst(); + int dst_input = edges[i]->dst_input(); + graph->RemoveEdge(edges[i]); + graph->AddEdge(new_arg.enter, 0, dst, dst_input); + } + args.push_back(new_arg); + } + } + } + frame->args = std::move(args); + + std::sort(frame->args.begin(), frame->args.end(), + [](const Arg& a, const Arg& b) { + return NodeCmpByNameResourcesLast()(a.enter, b.enter); + }); + + if (frame->loop_cond == nullptr) { + return errors::InvalidArgument("Loop ", frame->name, + " has no LoopCond node"); + } + + // Find the set of Switch nodes that are successors of the LoopCond. + std::unordered_set switches; + for (const Edge* edge : frame->loop_cond->out_edges()) { + if (!edge->IsControlEdge() && IsSwitch(edge->dst()) && + edge->dst_input() == 1) { + switches.insert(edge->dst()); + } + } + + // For each non-constant argument, looks for the following pattern of nodes: + // Enter ----> Merge --------> Switch --> Exit + // ^ ^ + // | | + // NextIteration LoopCond + // ^ ^ + // | | + // ... ... + for (Arg& arg : frame->args) { + if (!arg.is_loop_invariant) { + // Follow the edge from the Enter to Merge. + const Edge* enter_merge = nullptr; + for (const Edge* e : arg.enter->out_edges()) { + // Ignore control-edges to the sink node. These are allowed by the + // graph invariants, although probably they should have been stripped + // off earlier. + if (e->IsControlEdge() && e->dst()->IsSink()) { + continue; + } + if (enter_merge != nullptr) { + return errors::Internal("Enter node for loop-varying argument ", + FormatNodeForError(*arg.enter), + " has multiple successors: ", + FormatNodeForError(*enter_merge->dst()), + " and ", FormatNodeForError(*e->dst())); + } + enter_merge = e; + } + if (enter_merge == nullptr) { + return errors::Internal("Enter node for loop-varying argument ", + FormatNodeForError(*arg.enter), + " has zero successors"); + } + arg.merge = enter_merge->dst(); + if (!IsMerge(arg.merge)) { + return errors::InvalidArgument( + "Successor of Enter node for loop-varying argument ", + FormatNodeForError(*arg.merge), + " is not a Merge node; got: ", arg.merge->type_string()); + } + + // Find the NextIteration from the merge. There should be two inputs to + // the Merge and the NextIteration should be the other input. + if (arg.merge->input_types().size() != 2) { + return errors::InvalidArgument( + "Unexpected number of inputs to Merge node for loop-varying " + "argument ", + FormatNodeForError(*arg.merge), "; expected 2, got ", + arg.merge->input_types().size()); + } + TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), + &arg.next_iteration)); + if (!IsNextIteration(arg.next_iteration)) { + return errors::InvalidArgument( + "Expected NextIteration node as input to Merge node; got node ", + FormatNodeForError(*arg.next_iteration), " with kind ", + arg.next_iteration->type_string()); + } + + // Find the Switch successor of the Merge. There should be exactly one + // Switch node that is a successor of both the Merge and the LoopCond. + for (const Edge* edge : arg.merge->out_edges()) { + if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && + switches.find(edge->dst()) != switches.end()) { + if (arg.switch_node != nullptr) { + return errors::InvalidArgument("Duplicate Switch successors to ", + FormatNodeForError(*arg.merge)); + } + arg.switch_node = edge->dst(); + } + } + if (arg.switch_node == nullptr) { + return errors::InvalidArgument("Missing Switch successor to ", + FormatNodeForError(*arg.merge)); + } + + // Update the device on the Identity outputs of the switch to match their + // target. These Identity outputs do not + + // Loop over the switch node's output to: + // - Find the Exit successor. + // - Set the sharding on all Identity outputs of the switch. These + // identity nodes are values used by the loop body or condition. + // The Identity node may have the wrong device so copy the device from + // one of its outputs instead. + std::deque possible_exit; + for (const Edge* edge : arg.switch_node->out_edges()) { + if (edge->src_output() == 0) { + possible_exit.push_back(edge); + } + if (IsIdentity(edge->dst())) { + TF_RETURN_IF_ERROR( + SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); + } + } + // TODO(b/67425339): Allow general graph between switch and exit. + while (!possible_exit.empty()) { + const Edge* edge = possible_exit.front(); + possible_exit.pop_front(); + if (IsExit(edge->dst())) { + if (arg.exit != nullptr) { + return errors::InvalidArgument( + "Duplicate Exit successors to ", + FormatNodeForError(*arg.switch_node)); + } + arg.exit = edge->dst(); + } else { + if (!IsIdentity(edge->dst())) { + return errors::Unimplemented("General graph between switch (", + FormatNodeForError(*arg.switch_node), + ") and exit node of frame ", + frame->name, " not supported yet."); + } + for (const Edge* out : edge->dst()->out_edges()) { + possible_exit.push_back(out); + } + } + } + } + } + + // Builds the condition and body functions. Notice that we call + // FunctionalizeCond() on cond_graph and body_graph because we might have + // unfunctionalized "if" in cond_graph and body_graph. Functionalize them + // before they are encapsulated in FunctionDef. + std::unique_ptr cond_graph; + TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + FixupSourceAndSinkEdges(cond_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library)); + DataTypeVector arg_types; + std::unique_ptr body_graph; + TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + FixupSourceAndSinkEdges(body_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); + + VLOG(2) << "Frame " << frame->name << " condition: " + << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) + << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); + + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + NameAttrList cond_name; + cond_name.set_name(absl::StrCat("_functionalize_cond_", id)); + NameAttrList body_name; + body_name.set_name(absl::StrCat("_functionalize_body_", id)); + FunctionDef cond_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); + FunctionDef body_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef)); + + TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + if (lookup_library) { + // Copy missing FunctionDefs from lookup_library to library to make library + // self-contained. + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(cond_fdef, lookup_library, library)); + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(body_fdef, lookup_library, library)); + } + + // Builds a While operator. + NodeDef while_def; + NodeDefBuilder builder(frame->loop_cond->name(), "While", library); + builder.Attr("T", arg_types); + builder.Attr("cond", cond_name); + builder.Attr("body", body_name); + std::vector inputs; + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + inputs.push_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), arg_types[i])); + } + } + builder.Input(inputs); + TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); + TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph)); + + // Copies edges to the Enter nodes and from the Exit nodes onto the While. + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph->AddControlEdge(in_edge->src(), while_node); + } else { + graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i); + } + + if (!arg.is_loop_invariant) { + // Add output edges if the output of the loop is consumed. + if (arg.exit != nullptr) { + std::vector edges(arg.exit->out_edges().begin(), + arg.exit->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + + if (dst_input == Graph::kControlSlot) { + graph->AddControlEdge(while_node, dst); + } else { + graph->AddEdge(while_node, i, dst, dst_input); + } + } + } + } + } + + // Remove the old nodes from the graph, and add the while node to the parent + // frame. + for (Node* node : frame->nodes) { + graph->RemoveNode(node); + } + frame->nodes.clear(); + frame->parent->nodes.insert(while_node); + + VLOG(2) << "Frame " << frame->name << " after: " + << dump_graph::DumpGraphToFile("functionalize_after", *graph, + library); + + return Status::OK(); +} +} // namespace + +Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library) { + // Note: BuildControlFlowInfo() requires that the graph's source node is + // connected to all source nodes in the graph. Many graphs violate this + // invariant. + std::vector cf_info; + std::vector unreachable_nodes; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); + if (!unreachable_nodes.empty()) { + return errors::InvalidArgument( + "The following nodes are unreachable from the source in the graph: ", + errors::FormatNodeNamesForError(unreachable_nodes)); + } + + // Builds Frames, indexed by name. + std::unordered_map frames; + for (Node* node : graph->op_nodes()) { + const ControlFlowInfo& cf = cf_info[node->id()]; + + VLOG(2) << "node: " << node->name() << " (" << node->id() + << ") frame_name: " << cf.frame_name + << " frame: " << (cf.frame ? cf.frame->name() : "---") + << " parent_frame: " + << (cf.parent_frame ? cf.parent_frame->name() : "---"); + TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); + + Frame& frame = frames[cf.frame_name]; + Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; + if (frame.parent == nullptr) { + frame.parent = parent; + frame.name = cf.frame_name; + ++parent->num_children; + } + + if (IsEnter(node)) { + Arg arg; + arg.enter = node; + TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", + &arg.is_loop_invariant)); + frame.args.push_back(arg); + } else if (IsLoopCond(node)) { + frame.loop_cond = node; + } + frame.nodes.insert(node); + } + + // Adds frames with no children (i.e., the innermost frames) to a worklist. + std::deque worklist; + for (auto& frame : frames) { + if (frame.second.num_children == 0) { + worklist.push_back(&frame.second); + } + } + + // Eliminate loops from innermost to outermost. + while (!worklist.empty()) { + Frame* frame = worklist.front(); + worklist.pop_front(); + if (frame->parent == frame) { + // Skip the root frame. + continue; + } + + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); + + // If the parent has no remaining children, add it to the worklist. + --frame->parent->num_children; + if (frame->parent->num_children == 0) { + worklist.push_back(frame->parent); + } + } + + // There should be no cycle at this point, since while loops have been removed + // from graph. + // Check that the newly added While nodes don't feed into themselves. + for (const Node* node : graph->op_nodes()) { + if (node->def().op() == "While") { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNodeNotInCycle(node, graph->num_node_ids()), + "Functionalizing loop failed."); + } + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.h b/tensorflow/compiler/tf2xla/functionalize_while.h new file mode 100644 index 0000000000000000000000000000000000000000..a708c6e4ec4e13527b4ee2d6c435dddee0a2b4e2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_while.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Transformation that converts tf.while_loop() loops into functional While +// operators, suitable for XLA compilation. If lookup_library is provided, use +// it to make the library for control flow self-contained. +Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, FunctionLibraryDefinition* library); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index e4fdf0a6186eb69a2e3413838c91616b992ef2d6..c019a28e892ff89f559ddbec2360d6caa9c1808f 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -57,7 +56,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, std::vector compile_time_constant_flags(expressions.size()); TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &compile_time_constant_flags)); + BackwardsConstAnalysis(*graph, &compile_time_constant_flags, + /*compile_time_const_nodes=*/nullptr)); args->resize(expressions.size()); for (int i = 0; i < args->size(); ++i) { @@ -80,7 +80,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, TF_ASSIGN_OR_RETURN(auto literal, client->ComputeConstant(constant_graph)); TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); + LiteralToHostTensor(literal, arg.type, &arg.constant_value)); } else { arg.kind = XlaCompiler::Argument::kParameter; } @@ -126,7 +126,7 @@ Status GraphCompiler::Compile() { TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch()) << "Not supported node: " << n->DebugString(); params.op_kernel = op_kernel.get(); - gtl::InlinedVector output_attr(n->num_outputs()); + absl::InlinedVector output_attr(n->num_outputs()); params.output_attr_array = output_attr.data(); // tensor_inputs_ is a buffer reused across graph traversal. We clean up and @@ -145,6 +145,7 @@ Status GraphCompiler::Compile() { } OpKernelContext op_context(¶ms, n->num_outputs()); + VLOG(3) << "Translating " << params.op_kernel->name(); if (IsFunctional(n)) { TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); } else { diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index 127562eb23d775f17179cc9ee968ec2255cf3a14..e9f02201cf6bed5495dff7dff76c5bafe7771516 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -55,17 +55,17 @@ namespace tensorflow { // op registration infrastructure instead of FunctionLibraryRuntime. class GraphCompiler { public: - GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device, - Graph* graph, FunctionLibraryRuntime* flib, + GraphCompiler(XlaCompilationDevice* device, Graph* graph, + FunctionLibraryRuntime* flib, ScopedStepContainer* step_container) - : xla_context_(xla_context), - device_(device), + : device_(device), graph_(graph), flib_(flib), step_container_(step_container) {} - // Compiles the graph. The results are written in `xla_context` that is passed - // into the compiler. + // Compiles the graph. The results are written in xla_context stored in the + // resource_manager of the 'XlaCompilationDevice' that's passed into the + // constructor. Status Compile(); private: @@ -82,14 +82,13 @@ class GraphCompiler { // using `compiler_`. Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); - XlaContext* xla_context_; XlaCompilationDevice* device_; Graph* graph_; FunctionLibraryRuntime* flib_; ScopedStepContainer* step_container_; // A buffer to hold tensor inputs to a node, this is reused across the graph // traversal. - gtl::InlinedVector tensor_inputs_; + absl::InlinedVector tensor_inputs_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 0609e223381550645d1a41ba75e4cd57f893ee95..3e823254d3d52e88552712b4f53fa4449586cd20 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -6,6 +6,10 @@ package( load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load( + "//third_party/mkl:build_defs.bzl", + "if_mkl", +) tf_kernel_library( name = "xla_ops", @@ -18,6 +22,7 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "broadcast_to_op.cc", "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", @@ -96,12 +101,19 @@ tf_kernel_library( "unary_ops.cc", "unpack_op.cc", "variable_ops.cc", + "xla_broadcast_helper_op.cc", + "xla_conv_op.cc", + "xla_dot_op.cc", + "xla_pad_op.cc", + "xla_reduce_op.cc", + "xla_select_and_scatter_op.cc", ], hdrs = [ "index_ops.h", "shape_util.h", ], deps = [ + ":conv_op_helpers", ":if_op", ":while_op", "//tensorflow/compiler/tf2xla:common", @@ -129,6 +141,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/core:framework", @@ -154,6 +167,30 @@ tf_kernel_library( "//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:transpose_op", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "conv_op_helpers", + srcs = ["conv_op_helpers.cc"], + hdrs = ["conv_op_helpers.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/kernels:conv_ops", + "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/types:span", ], ) @@ -163,6 +200,7 @@ tf_kernel_library( hdrs = ["while_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", @@ -180,6 +218,7 @@ tf_kernel_library( hdrs = ["if_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 26fc1620a4f032b3af28de6e3a5af0e965e82341..276d744c096f8996c774964204feaa3762bdb844 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -65,6 +65,6 @@ class XlaArgOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp); }; -REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp); +REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index b3ad0aea84eef601de08909f760699b8700d28f4..a267c0c72fce67d7c22c55a57f8d5ac4ffd2b7e2 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -34,12 +34,6 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); - OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || - data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), - errors::InvalidArgument( - "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -110,12 +104,6 @@ class FusedBatchNormGradOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); - OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || - data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), - errors::InvalidArgument( - "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 48f2a005ab16651fe29d0f6f9d881f95693da461..a18e04995b5e1e0b0374f7b0edd6f5e114cf994a 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -23,10 +23,10 @@ namespace { void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, DataType input_dtype, const TensorShape& input_tensor_shape, - gtl::ArraySlice block_shape, + absl::Span block_shape, const xla::Literal& crops) { const int input_rank = input_tensor_shape.dims(); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); const int block_rank = block_shape.size(); @@ -34,7 +34,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, ctx, input_rank >= 1 + block_rank, errors::InvalidArgument("input rank should be >= ", 1 + block_rank, " instead of ", input_rank)); - gtl::ArraySlice remainder_shape(input_shape); + absl::Span remainder_shape(input_shape); remainder_shape.remove_prefix(1 + block_rank); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index ba3b1c9dab79a387c48e8e25e4804917f328f8a0..182f7c99344845964f7010127718f876ab6e8a44 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Ops for broadcasting used in gradient // code. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -38,7 +39,7 @@ class BCastArgsOp : public XlaOpKernel { OP_REQUIRES( ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const TensorShape in_shape = ctx->InputShape(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), @@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. [", absl::StrJoin(shapes[1], ","), "]")); const int64 len = bcast.output_shape().size(); Tensor output(DT_INT32, TensorShape({len})); @@ -87,7 +88,7 @@ class BCastGradArgsOp : public XlaOpKernel { ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const TensorShape in_shape = ctx->InputShape(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), @@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. [", absl::StrJoin(shapes[1], ","), "]")); Output(ctx, 0, bcast.grad_x_reduce_idx()); Output(ctx, 1, bcast.grad_y_reduce_idx()); } diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 2c328102e0bd84709707f102272691b6aec9a577..a988d3c33ed808b022f67882c8ae5100b7e7a305 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -30,21 +31,21 @@ namespace { // A subclass of a XlaBinaryOp must build the computation that // describes the (tensor,tensor)->tensor function to apply to each element of // the input. -#define XLA_MAKE_BINARY(NAME, HLO) \ - class NAME##Op : public XlaBinaryOp { \ - public: \ - explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ - xla::XlaOp Computation( \ - XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ - const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, \ - const gtl::ArraySlice& rhs_shape, \ - const BCast& broadcast_helper, \ - const std::vector& extend_dimensions) override { \ - xla::XlaBuilder* b = ctx->builder(); \ - (void)b; \ - return HLO; \ - } \ - }; \ +#define XLA_MAKE_BINARY(NAME, HLO) \ + class NAME##Op : public XlaBinaryOp { \ + public: \ + explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ + xla::XlaOp Computation( \ + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ + const absl::Span& lhs_shape, const xla::XlaOp& rhs, \ + const absl::Span& rhs_shape, \ + const BCast& broadcast_helper, \ + const std::vector& extend_dimensions) override { \ + xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ + return HLO; \ + } \ + }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op) XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions)); @@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); +// Implementation of DivNoNan. Pseudo-code: +// if (y == 0) { +// return 0 +// } else { +// return x / y; +// } +static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto y_equals_0 = xla::Eq(y, zero); + auto zeros = xla::ZerosLike(x); + auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y)); + return result; +} +XLA_MAKE_BINARY(DivNoNan, + DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { // T abs_x = std::abs(x); @@ -66,6 +85,9 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + if (DataTypeIsUnsigned(dtype)) { + return xla::Div(x, y); + } auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)); @@ -81,6 +103,24 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, XLA_MAKE_BINARY(FloorDiv, FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto is_zero = xla::Eq(x, zero); + return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y))); +} +XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + +static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto is_zero = xla::Eq(x, zero); + return xla::Select(is_zero, zero, xla::Div(x, y)); +} +XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..696c1c39befd5aa2972afb6cfa64905b57a5ab72 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { +namespace { + +class BroadcastToOp : public XlaOpKernel { + public: + explicit BroadcastToOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + TensorShape output_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); + + OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(), + errors::InvalidArgument( + "Input rank (", input_shape.dims(), + ") must be less than or equal to the output rank (", + output_shape.dims(), ")")); + + auto input_dims = input_shape.dim_sizes(); + auto output_dims = output_shape.dim_sizes(); + + // Broadcasting is done right-to-left on right-aligned dimensions; reverse + // the two vectors so elements to be broadcast are aligned. + absl::c_reverse(input_dims); + absl::c_reverse(output_dims); + + std::vector broadcast_dims; + std::vector broadcast_shape; + for (int i = 0; i < output_shape.dims(); ++i) { + if (i < input_shape.dims()) { + OP_REQUIRES( + context, + (output_dims[i] == 0 && input_dims[i] == 0) || + (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), + errors::InvalidArgument("invalid shape to broadcast from ", + input_shape.DebugString(), " to ", + output_shape.DebugString())); + + broadcast_dims.push_back(broadcast_shape.size()); + if (output_dims[i] == input_dims[i]) { + broadcast_shape.push_back(output_dims[i]); + } else if (output_dims[i] != input_dims[i]) { + // Add dimensions [I, O/I], which we will later flatten to just + // [O]. We must do this in two phases since XLA broadcasting does not + // support tiling. + broadcast_shape.push_back(input_dims[i]); + broadcast_shape.push_back(output_dims[i] / input_dims[i]); + } + } else { + broadcast_shape.push_back(output_dims[i]); + } + } + absl::c_reverse(broadcast_dims); + int broadcast_shape_size = broadcast_shape.size(); + for (int64& broadcast_dim : broadcast_dims) { + broadcast_dim = broadcast_shape_size - broadcast_dim - 1; + } + absl::c_reverse(broadcast_shape); + xla::XlaOp output = xla::Reshape( + xla::BroadcastInDim(context->Input(0), + xla::ShapeUtil::MakeShape( + context->input_xla_type(0), broadcast_shape), + broadcast_dims), + output_shape.dim_sizes()); + context->SetOutput(0, output); + } +}; + +REGISTER_XLA_OP(Name("BroadcastTo").CompileTimeConstInput("shape"), + BroadcastToOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index f4106051043859a6786705009d76b02a64cd3ff1..0ae23aa6dfe49048ac5cb8ae00c12432b2e2a2fe 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -37,6 +37,16 @@ limitations under the License. namespace tensorflow { namespace { +// Used to determine the number of Tensors allowed in a Concat op to prevent +// going over the max gpu parameter memory size. This is an issue because concat +// is variadic and can have an unlimited number of arguments when called. +// Concat ops with more Tensors than this will be split into multiple concat +// ops. +// +// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass +// along with boxing large numbers of parameters. +constexpr int64 kMaxConcatArgsPerOp = 500; + // -------------------------------------------------------------------------- class ConcatBaseOp : public XlaOpKernel { public: @@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel { // Make a vector holding the XlaOp for each of the inputs that has non-zero // elements. std::vector input_data; + std::vector partial_concats; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { @@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel { input_data.push_back(handle); } output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; + + // Concat is associative, so it can be split into many operations when too + // many arguments are in a single op. This is a temporary workaround for + // b/112613927 where too many parameters in an XlaLaunchOp later result in + // too many parameters to a single GPU kernel. + if (i && i % kMaxConcatArgsPerOp == 0) { + partial_concats.push_back( + xla::ConcatInDim(ctx->builder(), input_data, axis)); + input_data.clear(); + } } + // Add any inputs that have not been put into another concat yet. + partial_concats.insert(partial_concats.end(), input_data.begin(), + input_data.end()); VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); + // Don't add an additional "identity" concatenate for better readibility of + // IR. + if (partial_concats.size() == 1) { + ctx->SetOutput(0, partial_concats.front()); + } else { + ctx->SetOutput(0, + xla::ConcatInDim(ctx->builder(), partial_concats, axis)); + } } private: diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9a1be494066e4f935a1d818bc86c86333e34fae --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -0,0 +1,509 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for 2D convolution. + +#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +// Returns the expanded size of a filter used for depthwise convolution. +// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. +xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) { + int num_dims = shape.dimensions_size(); + CHECK_GE(num_dims, 2); // Crash OK + xla::Shape expanded_shape = shape; + expanded_shape.set_dimensions( + num_dims - 1, + shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1)); + return expanded_shape; +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tensor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. +xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape, + xla::XlaBuilder* builder) { + xla::Shape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = + filter_shape.dimensions(filter_shape.dimensions_size() - 1); + int64 input_feature = + filter_shape.dimensions(filter_shape.dimensions_size() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + xla::Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + std::vector expanded_feature_broadcast_dims( + expanded_filter_shape.dimensions().begin(), + expanded_filter_shape.dimensions().end()); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = + xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dimensions_size() - 2}); +} + +// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to +// build a depthwise convolution. +xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape, + const xla::XlaOp& filter) { + int64 input_feature_dim = filter_shape.dimensions_size() - 2; + int64 output_feature_dim = filter_shape.dimensions_size() - 1; + int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim); + int64 input_feature = filter_shape.dimensions(input_feature_dim); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + xla::Shape implicit_broadcast_filter_shape = filter_shape; + implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1); + implicit_broadcast_filter_shape.set_dimensions( + output_feature_dim, depthwise_multiplier * input_feature); + return xla::Reshape( + filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions())); +} + +// Reduces the results of the convolution with an expanded filter to the +// non-expanded filter. +xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape, + const xla::XlaOp& filter_backprop, + xla::XlaBuilder* builder) { + auto masked_expanded_filter = + xla::Select(CreateExpandedFilterMask(filter_shape, builder), + filter_backprop, xla::ZerosLike(filter_backprop)); + + auto elem_type = filter_shape.element_type(); + return xla::Reshape( + // This reduce does not need inputs to be converted with + // XlaHelpers::SumAccumulationType() since the select above guarantees + // that only one element is non zero, so there cannot be accumulated + // precision error. + xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type), + CreateScalarAddComputation(elem_type, builder), + {filter_shape.dimensions_size() - 2}), + xla::AsInt64Slice(filter_shape.dimensions())); +} + +// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA +// convolutions (as currently implemented). +Status CheckConvAttrs(const ConvOpAttrs& attrs) { + const int num_dims = attrs.num_spatial_dims + 2; + if (attrs.strides.size() != num_dims) { + return errors::InvalidArgument("Sliding window strides field must specify ", + num_dims, " dimensions"); + } + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not yet support strides in the batch and " + "depth dimensions."); + } + if (attrs.dilations.size() != num_dims) { + return errors::InvalidArgument("Dilations field must specify ", num_dims, + " dimensions"); + } + if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not support dilations in the batch and " + "depth dimensions."); + } + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + if (attrs.dilations[input_dim] < 1) { + return errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + attrs.dilations[input_dim]); + } + } + return Status::OK(); +} + +// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes +// to TensorShapes. +Status ConvBackpropComputeDimensionsV2XlaShapes( + StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, + const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, + absl::Span dilations, const std::vector& strides, + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) { + TensorShape input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); + return ConvBackpropComputeDimensionsV2( + label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape, dilations, strides, padding, data_format, + dims); +} + +} // anonymous namespace + +xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, + bool depthwise, + OpKernelConstruction* ctx) { + ConvOpAttrs attrs; + attrs.num_spatial_dims = num_spatial_dims; + attrs.depthwise = depthwise; + TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); + TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); + TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + + string data_format; + TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); + if (!FormatFromString(data_format, &attrs.data_format)) { + return errors::InvalidArgument("Invalid data format: ", data_format); + } + + return attrs; +} + +xla::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = conv_input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input)); + // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth] + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + + // For 2D convolution, there should be 4 dimensions. + int num_dims = attrs.num_spatial_dims + 2; + if (input_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument("input must be ", num_dims, "-dimensional", + input_shape.DebugString()); + } + if (filter_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument( + "filter must be ", num_dims, + "-dimensional: ", filter_shape.DebugString()); + } + + // The last two dimensions of the filter are the input and output shapes. + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims); + // The 'C' dimension for input is in_depth. It must be the same as + // the filter's in_depth. + if (in_depth != input_shape.dimensions(feature_dim)) { + return errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, " vs ", + input_shape.dimensions(feature_dim)); + } + + if (attrs.depthwise) { + filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); + } + + xla::ConvolutionDimensionNumbers dims; + std::vector window_strides(attrs.num_spatial_dims); + std::vector lhs_dilation(attrs.num_spatial_dims, 1); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector> padding(attrs.num_spatial_dims); + + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims); + dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dims.add_input_spatial_dimensions(dim); + dims.add_kernel_spatial_dimensions(i); + dims.add_output_spatial_dimensions(dim); + window_strides[i] = attrs.strides.at(dim); + rhs_dilation[i] = attrs.dilations.at(dim); + + int64 unused_output_size; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + input_shape.dimensions(dim), filter_shape.dimensions(i), + rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size, + &padding[i].first, &padding[i].second)); + } + + return xla::ConvGeneralDilated( + conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, + dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1); +} + +xla::StatusOr MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + int num_dims = attrs.num_spatial_dims + 2; + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + auto* builder = filter.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(out_backprop)); + + xla::Shape expanded_filter_shape = + attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape, + out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding, + attrs.data_format, &dims)); + + // The input gradients are computed by a convolution of the output + // gradients and the filter, with some appropriate padding. See the + // comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(batch_dim); + dnums.set_output_batch_dimension(batch_dim); + dnums.set_input_feature_dimension(feature_dim); + dnums.set_output_feature_dimension(feature_dim); + + // TF filter shape is [ H, W, ..., inC, outC ] + // Transpose the input and output features for computing the gradient. + dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1); + dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims); + + std::vector kernel_spatial_dims(attrs.num_spatial_dims); + std::vector> padding(attrs.num_spatial_dims); + std::vector lhs_dilation(attrs.num_spatial_dims); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector ones(attrs.num_spatial_dims, 1); + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(i); + dnums.add_output_spatial_dimensions(dim); + + kernel_spatial_dims[i] = i; + padding[i] = {dims.spatial_dims[i].pad_before, + dims.spatial_dims[i].pad_after}; + lhs_dilation[i] = dims.spatial_dims[i].stride; + rhs_dilation[i] = attrs.dilations[dim]; + } + + // Mirror the filter in the spatial dimensions. + xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); + + // activation gradients + // = gradients (with padding and dilation) mirrored_weights + return xla::ConvGeneralDilated( + out_backprop, mirrored_weights, /*window_strides=*/ones, padding, + lhs_dilation, rhs_dilation, dnums, + /*feature_group_count=*/ + attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) / + filter_shape.dimensions(attrs.num_spatial_dims + 1) + : 1); +} + +xla::StatusOr MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = activations.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape activations_shape, + builder->GetShape(activations)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(gradients)); + const xla::Shape expanded_filter_shape = + attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims)); + + // The filter gradients are computed by a convolution of the input + // activations and the output gradients, with some appropriate padding. + // See the comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + + // The last two dimensions of the filter are the input and output shapes. + int num_dims = attrs.num_spatial_dims + 2; + int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + // Swap n_dim and c_dim in the activations. + dnums.set_input_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] + // where the batch becomes the input feature for the convolution. + dnums.set_kernel_input_feature_dimension(n_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + + std::vector> padding(attrs.num_spatial_dims); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector window_strides(attrs.num_spatial_dims); + std::vector ones(attrs.num_spatial_dims, 1); + + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(attrs.num_spatial_dims); + dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(dim); + + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + // + const int64 padded_in_size = + dims.spatial_dims[i].expanded_output_size + + (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim]; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; + + // + For the VALID padding, we don't pad anything on the top/left side + // and pad the bottom/right side with the remaining space. + // + For the SAME padding, we pad top/left side the same as bottom/right + // side. + // + // In addition, if the padded input size is smaller than the input size, + // we need to ignore some training elements of the input. We do this by + // applying negative padding on the right/bottom. + const int64 pad_before = + attrs.padding == Padding::SAME ? std::max(pad_total / 2, 0) : 0; + + padding[i] = {pad_before, pad_total - pad_before}; + rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = attrs.dilations[dim]; + } + + // Besides padding the input, we will also expand output_rows to + // expanded_out_rows = (output_rows - 1) * stride + 1 + // with zeros in between: + // + // a . . . b . . . c . . . d . . . e + // + // This is done by specifying the window dilation factors in the + // convolution HLO below. + auto filter_backprop = + xla::ConvGeneralDilated(activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums); + + if (attrs.depthwise) { + filter_backprop = ContractFilterForDepthwiseBackprop( + filter_shape, filter_backprop, activations.builder()); + } + + return filter_backprop; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..6e1b70a47850ae5c05939f8dfb7ec129c031df21 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +// This header exposes utilities for translating TensorFlow convolution ops into +// XLA ops. +// +// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g. +// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in +// this header to implement a new and exciting convolution op, for example a +// fused TensorFlow op that contains a convolution and other things. + +namespace tensorflow { + +// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA +// convolution. +struct ConvOpAttrs { + // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`. + static xla::StatusOr Create(int num_spatial_dims, bool depthwise, + OpKernelConstruction* ctx); + + bool depthwise; + int num_spatial_dims; + std::vector dilations; + std::vector strides; + Padding padding; + TensorFormat data_format; +}; + +// Creates a new XLA forward or backward convolution with the given inputs and +// attributes. +xla::StatusOr MakeXlaForwardConvOp(StringPiece type_string, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs); +xla::StatusOr MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs); +xla::StatusOr MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 5da7972397b32fb4a2f216913e065c04131a3773..cd7c820be0b6029514ff74288e7bdd3f75b5d6b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -15,12 +15,17 @@ limitations under the License. // XLA-specific Ops for 2D convolution. +#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -33,271 +38,28 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { - namespace { -// Returns the expanded size of a filter used for depthwise convolution. -// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. -TensorShape ExpandedFilterShapeForDepthwiseConvolution( - const TensorShape& shape) { - int num_dims = shape.dims(); - CHECK_GE(num_dims, 2); - TensorShape expanded_shape = shape; - expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) * - shape.dim_size(num_dims - 1)); - return expanded_shape; -} - -// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. -xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return xla::Broadcast(XlaHelpers::Zero(builder, dtype), - expanded_filter_shape.dim_sizes()); -} - -// Create a mask for depthwise convolution that will make a normal convolution -// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] -// depthwise filter this returns a [2, 2, 3, 6] tensor -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// The first step is to create a one tensor, A, that is [3] -// 0 1 2 -// -// and another tensor, B, that is [3 * 2] -// 0 1 2 3 4 5 -// -// and divide B it by 2 to get -// 0 0 1 1 2 2 -// -// then we broadcast the B to [2, 2, 3, 3 * 2] -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// Finally compare A and broadcasted B in dimension 2 amd return the result at -// the beginning of the comment. -xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); - int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); - - // Create a M sized linspace and an M*N sized linspace that will be - // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); - xla::XlaOp expanded_feature_iota = - xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); - - // Divide the M*N sized linspace by the depthwise_multiplier to create - // [0 0 1 1 2 2] in the example in the function comment. - expanded_feature_iota = - xla::Div(expanded_feature_iota, - XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, - depthwise_multiplier)); - - // Broadcast the N*M linspace to [H, W, ..., M, M*N]. - auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); - expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = - xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); - - // Compare the broadcasted linspace to the input feature linspace in the - // input feature dimension to create a diagonal predicate. - return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dims() - 2}); -} - -// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding -// zeros for the cross-depth filters. Used to build a depthwise convolution. -xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, - DataType dtype, - const xla::XlaOp& filter, - xla::XlaBuilder* builder) { - int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); - int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - - // Create a [H, W, ..., 1, N*M] reshape of the filter. - TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; - implicit_broadcast_filter_shape.set_dim( - implicit_broadcast_filter_shape.dims() - 2, 1); - implicit_broadcast_filter_shape.set_dim( - implicit_broadcast_filter_shape.dims() - 1, - depthwise_multiplier * input_feature); - auto implicit_broadcast_filter = - xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); - - // Broadcast the filter to [H, W, ..., M, M*N]. - auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); - auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero); - - // If the filter mask is set, choose the broadcasted filter, othwerwise, - // choose zero. - return xla::Select(CreateExpandedFilterMask(filter_shape, builder), - expanded_filter, expanded_zero); -} - -// Inverse of ExpandFilterForDepthwiseConvolution. -xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, - const TensorShape& filter_shape, - DataType dtype, - const xla::XlaOp& filter_backprop, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - auto masked_expanded_filter = xla::Select( - CreateExpandedFilterMask(filter_shape, builder), filter_backprop, - CreateExpandedZero(filter_shape, dtype, builder)); - return xla::Reshape( - // This reduce does not need inputs to be converted with - // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with - // ExpandedZero guarantees that only one element is non zero, so there - // cannot be accumulated precision error. - xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), - {expanded_filter_shape.dims() - 2}), - filter_shape.dim_sizes()); -} - class ConvOp : public XlaOpKernel { public: explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape input_shape = ctx->InputShape(0); - // Input filter is of the following dimensions: - // [ filter_rows, filter_cols, ..., in_depth, out_depth] - const TensorShape filter_shape = ctx->InputShape(1); - - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES( - ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("input must be ", num_dims(), "-dimensional", - input_shape.DebugString())); - OP_REQUIRES( - ctx, filter_shape.dims() == num_dims(), - errors::InvalidArgument("filter must be ", num_dims(), - "-dimensional: ", filter_shape.DebugString())); - - // The last two dimension of the filter are the input and output shapes. - const int64 in_depth = filter_shape.dim_size(num_spatial_dims_); - - // The 'C' dimension for input is in_depth. It must be the same as - // the filter's in_depth. - OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim), - errors::InvalidArgument( - "input and filter must have the same depth: ", in_depth, - " vs ", input_shape.dim_size(feature_dim))); - - xla::XlaBuilder* b = ctx->builder(); - - xla::XlaOp filter = ctx->Input(1); - TensorShape expanded_filter_shape = filter_shape; - if (depthwise_) { - filter = ExpandFilterForDepthwiseConvolution( - filter_shape, ctx->input_type(0), filter, b); - expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - } - - xla::ConvolutionDimensionNumbers dims; - std::vector window_strides(num_spatial_dims_); - std::vector lhs_dilation(num_spatial_dims_, 1); - std::vector rhs_dilation(num_spatial_dims_); - std::vector> padding(num_spatial_dims_); - - dims.set_input_batch_dimension(batch_dim); - dims.set_output_batch_dimension(batch_dim); - dims.set_input_feature_dimension(feature_dim); - dims.set_output_feature_dimension(feature_dim); - dims.set_kernel_input_feature_dimension(num_spatial_dims_); - dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dims.add_input_spatial_dimensions(dim); - dims.add_kernel_spatial_dimensions(i); - dims.add_output_spatial_dimensions(dim); - window_strides[i] = strides_.at(dim); - rhs_dilation[i] = dilations_.at(dim); - - int64 unused_output_size; - OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( - input_shape.dim_size(dim), expanded_filter_shape.dim_size(i), - rhs_dilation[i], window_strides[i], padding_, - &unused_output_size, &padding[i].first, &padding[i].second)); - } - - xla::XlaOp conv = - xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, - lhs_dilation, rhs_dilation, dims); - ctx->SetOutput(0, conv); + xla::StatusOr conv = MakeXlaForwardConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_); + OP_REQUIRES_OK(ctx, conv.status()); + ctx->SetOutput(0, conv.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector dilations_; - std::vector strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); @@ -329,127 +91,28 @@ class ConvBackpropInputOp : public XlaOpKernel { public: explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - TensorShape input_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); - - const TensorShape filter_shape = ctx->InputShape(1); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. - ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, input_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - xla::XlaBuilder* b = ctx->builder(); - auto filter = ctx->Input(1); - auto out_backprop = ctx->Input(2); - - // The input gradients are computed by a convolution of the output - // gradients and the filter, with some appropriate padding. See the - // comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(batch_dim); - dnums.set_output_batch_dimension(batch_dim); - dnums.set_input_feature_dimension(feature_dim); - dnums.set_output_feature_dimension(feature_dim); - - // TF filter shape is [ H, W, ..., inC, outC ] - // Transpose the input and output features for computing the gradient. - dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1); - dnums.set_kernel_output_feature_dimension(num_spatial_dims_); - - std::vector kernel_spatial_dims(num_spatial_dims_); - std::vector> padding(num_spatial_dims_); - std::vector lhs_dilation(num_spatial_dims_); - std::vector rhs_dilation(num_spatial_dims_); - std::vector ones(num_spatial_dims_, 1); - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(i); - dnums.add_output_spatial_dimensions(dim); - - kernel_spatial_dims[i] = i; - padding[i] = {dims.spatial_dims[i].pad_before, - dims.spatial_dims[i].pad_after}; - lhs_dilation[i] = dims.spatial_dims[i].stride; - rhs_dilation[i] = dilations_[dim]; - } - - // If this is a depthwise convolution, expand the filter. - if (depthwise_) { - filter = ExpandFilterForDepthwiseConvolution( - filter_shape, ctx->input_type(1), filter, b); - } - - // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); - - // activation gradients - // = gradients (with padding and dilation) mirrored_weights - xla::XlaOp in_backprop = xla::ConvGeneralDilated( - out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, rhs_dilation, dnums); - - ctx->SetOutput(0, in_backprop); + TensorShape input_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape)); + xla::Shape input_shape = + TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape); + + xla::StatusOr in_backprop = + MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape, + ctx->Input(1), ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, in_backprop.status()); + ctx->SetOutput(0, in_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector dilations_; - std::vector strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp); @@ -486,172 +149,28 @@ class ConvBackpropFilterOp : public XlaOpKernel { public: explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - - OP_REQUIRES( - ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1), - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape activations_shape = ctx->InputShape(0); - TensorShape filter_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - - // Reuse dimension computation logic from conv_grad_ops.cc. - ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, activations_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp activations = ctx->Input(0); - xla::XlaOp gradients = ctx->Input(2); - - // The filter gradients are computed by a convolution of the input - // activations and the output gradients, with some appropriate padding. - // See the comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - - // The activations (inputs) form the LHS of the convolution. - // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] - // For the gradient computation, we flip the roles of the batch and - // feature dimensions. - // Each spatial entry has size in_depth * batch - - // Swap n_dim and c_dim in the activations. - dnums.set_input_batch_dimension(c_dim); - dnums.set_input_feature_dimension(n_dim); - - // The gradients become the RHS of the convolution. - // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] - // where the batch becomes the input feature for the convolution. - dnums.set_kernel_input_feature_dimension(n_dim); - dnums.set_kernel_output_feature_dimension(c_dim); - - std::vector> padding(num_spatial_dims_); - std::vector rhs_dilation(num_spatial_dims_); - std::vector window_strides(num_spatial_dims_); - std::vector ones(num_spatial_dims_, 1); - - // Tensorflow filter shape is [ H, W, ..., inC, outC ]. - for (int i = 0; i < num_spatial_dims_; ++i) { - dnums.add_output_spatial_dimensions(i); - } - dnums.set_output_batch_dimension(num_spatial_dims_); - dnums.set_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(dim); - - // We will also need to pad the input with zeros such that after the - // convolution, we get the right size for the filter. - // The padded_in_rows should be such that when we convolve this with the - // expanded_out_rows as a filter, we should get filter_rows back. - // - const int64 padded_in_size = - dims.spatial_dims[i].expanded_output_size + - (dims.spatial_dims[i].filter_size - 1) * dilations_[dim]; - - // However it can be smaller than input_rows: in this - // case it means some of the inputs are not used. - // - // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: - // - // INPUT = [ A B C ] - // - // FILTER = [ x y ] - // - // and the output will only have one column: a = A * x + B * y - // - // and input "C" is not used at all. - // - // We apply negative padding in this case. - const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; - - // + For the VALID padding, we don't pad anything on the top/left side - // and pad the bottom/right side with the remaining space. - // + For the SAME padding, we pad top/left side the same as bottom/right - // side. - // - // In addition, if the padded input size is smaller than the input size, - // we need to ignore some training elements of the input. We do this by - // applying negative padding on the right/bottom. - const int64 pad_before = - padding_ == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - - padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = dilations_[dim]; - } - - // Besides padding the input, we will also expand output_rows to - // expanded_out_rows = (output_rows - 1) * stride + 1 - // with zeros in between: - // - // a . . . b . . . c . . . d . . . e - // - // This is done by specifying the window dilation factors in the - // convolution HLO below. - auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - - if (depthwise_) { - filter_backprop = ContractFilterForDepthwiseBackprop( - ctx, filter_shape, ctx->input_type(0), filter_backprop, b); - } - ctx->SetOutput(0, filter_backprop); + TensorShape filter_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape)); + xla::Shape filter_shape = + TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape); + + xla::StatusOr filter_backprop = MakeXlaBackpropFilterConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), filter_shape, + ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, filter_backprop.status()); + ctx->SetOutput(0, filter_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector dilations_; - std::vector strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp); diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index a5b870f8dbf70bcee331992345d63fd5d986bdca..6653944a911588b7bc88d67b8cdd2c17850530f0 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -57,8 +57,8 @@ class XlaBinaryOp : public XlaOpKernel { // in the XLA documentation. virtual xla::XlaOp Computation( XlaOpKernelContext* ctx, const xla::XlaOp& lhs, - const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, - const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, + const absl::Span& lhs_shape, const xla::XlaOp& rhs, + const absl::Span& rhs_shape, const BCast& broadcast_helper, const std::vector& extend_dimensions) = 0; void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 12b0e38288e8f222ed506a75ec2575f27141c859..e96a1adce43c750314715107b4a1954d4a5b4e40 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -48,7 +48,7 @@ class DepthToSpaceOp : public XlaOpKernel { OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got: ", input_rank)); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::XlaOp input = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ed44ad218b6dc073583ec339da082b6881ad672d..49c12fc232092873b69961644a059abc6035f64f 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -29,7 +29,7 @@ namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, - gtl::ArraySlice other_dims, + absl::Span other_dims, xla::PrimitiveType element_type) { xla::XlaBuilder* builder = input.builder(); // Create two matrices that have the following forms, and compare them: @@ -177,8 +177,8 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); - tensorflow::gtl::ArraySlice other_dims(dims); - other_dims.pop_back(); + absl::Span other_dims(dims); + other_dims.remove_suffix(1); xla::XlaOp input = ctx->Input(0); xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims, diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index a3389d5b905bf3ee15744ab4fcee193d312e2ae0..4af1e8b44cbbd02d8e3ea5e42d841c92288b5d56 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -34,15 +34,12 @@ class DynamicUpdateSliceOp : public XlaOpKernel { : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* ctx) override { - VLOG(3) << "DynamicUpdateSliceOp::Compile"; + DataType index_type = ctx->InputType("indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); - DataType index_type = input_type(2); - OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("index must be int32 or int64")); - - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape update_shape = ctx->InputShape(1); - const TensorShape index_shape = ctx->InputShape(2); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape update_shape = ctx->InputShape("update"); + const TensorShape index_shape = ctx->InputShape("indices"); OP_REQUIRES( ctx, @@ -57,13 +54,56 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::XlaOp result = - xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2)); + xla::XlaOp result = xla::DynamicUpdateSlice( + ctx->Input("input"), ctx->Input("update"), ctx->Input("indices")); ctx->SetOutput(0, result); } }; REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp); +class DynamicSliceOp : public XlaOpKernel { + public: + explicit DynamicSliceOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + DataType index_type = ctx->InputType("start_indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); + CHECK(index_type == ctx->InputType("size_indices")); + + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape start_indices_shape = ctx->InputShape("start_indices"); + const TensorShape size_indices_shape = ctx->InputShape("size_indices"); + + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(start_indices_shape) && + start_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "start_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and start_indices has shape ", + start_indices_shape.DebugString())); + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(size_indices_shape) && + size_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "size_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and size_indices has shape ", + size_indices_shape.DebugString())); + + std::vector size_indices; + OP_REQUIRES_OK( + ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices)); + xla::XlaOp result = xla::DynamicSlice( + ctx->Input("input"), ctx->Input("start_indices"), size_indices); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"), + DynamicSliceOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 35de96e0aab847fa39ef26d5f3052c392062fd7d..44140304fdf5cdf60d8ad8b85c532fcadff8ba86 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -95,11 +95,11 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, // operand = s32[3,3] parameter(0) // indices = s32[2] parameter(1) // gather = s32[3,2] gather(operand, indices), - // output_window_dims={0}, - // elided_window_dims={1}, - // gather_dims_to_operand_dims={1}, + // offset_dims={0}, + // collapsed_slice_dims={1}, + // start_index_map={1}, // index_vector_dim=1, - // window_bounds={3, 1} + // slice_sizes={3, 1} // // // Example of an N-D gather pulling out slices of shape [1,1,2] out of a @@ -108,42 +108,42 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, // operand = s32[3,3,2] parameter(0) // indices = s32[2,2] parameter(1) // gather = s32[2,2] gather(operand, indices), - // output_window_dims={1}, - // elided_window_dims={0,1}, - // gather_dims_to_operand_dims={0,1}, + // offset_dims={1}, + // collapsed_slice_dims={0,1}, + // start_index_map={0,1}, // index_vector_dim=0, - // window_bounds={1,1,2} + // slice_sizes={1,1,2} xla::GatherDimensionNumbers dim_numbers; - std::vector window_bounds; - window_bounds.reserve(input_shape.dims()); + std::vector slice_sizes; + slice_sizes.reserve(input_shape.dims()); for (int64 i = 0; i < input_shape.dims(); i++) { int64 window_bound; if (axis <= i && i < (axis + num_index_dims)) { - dim_numbers.add_elided_window_dims(i); + dim_numbers.add_collapsed_slice_dims(i); window_bound = 1; } else { window_bound = input_shape.dim_size(i); } - window_bounds.push_back(window_bound); + slice_sizes.push_back(window_bound); if (i < axis) { - dim_numbers.add_output_window_dims(i); + dim_numbers.add_offset_dims(i); } else if (i >= (axis + num_index_dims)) { int64 indices_rank = indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims(); - dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims); + dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); } } dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims()); for (int64 i = axis; i < axis + num_index_dims; i++) { - dim_numbers.add_gather_dims_to_operand_dims(i); + dim_numbers.add_start_index_map(i); } - *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds); + *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index e72200bfbcff20c55ac03030f1afc4bacaabf7ce..19dd38c46ef154ea74bcbb6721dd04924702efcc 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -25,7 +25,10 @@ class IdentityOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - ctx->SetOutput(i, ctx->Input(i)); + // Forwards using the underlying op_kernel_context so both tensor and + // resource values are forwarded correctly. + ctx->op_kernel_context()->set_output(i, + ctx->op_kernel_context()->input(i)); } } @@ -35,9 +38,10 @@ class IdentityOp : public XlaOpKernel { // XLA_* devices also register a "real" Identity operator so we suppress the // dummy operator using CompilationOnly(). -REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); - -REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp); +REGISTER_XLA_OP(Name("Identity").AllowResourceTypes().CompilationOnly(), + IdentityOp); +REGISTER_XLA_OP(Name("IdentityN").AllowResourceTypes().CompilationOnly(), + IdentityOp); REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index ceb2af756c2d2020c7449086b957c9fbc1cc2979..56da50f140893c68c8a1556853884720b21c7229 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/if_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } // TODO(b/35949885): There is duplication here with the handling of the @@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { options.resolve_compile_time_constants = false; options.return_updated_values_for_all_resources = true; options.is_entry_computation = false; + options.add_token_input_output = has_token_input_output_; XlaCompiler* compiler = ctx->compiler(); XlaCompiler::CompilationResult then_result; @@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = then_result.input_mapping[i] + 1; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "if" op. + std::vector token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(b, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); @@ -200,25 +216,36 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } } - xla::XlaOp outputs = xla::Conditional( - ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation, - xla::Tuple(b, inputs), *else_result.computation); + auto input_tuple = xla::Tuple(b, inputs); + xla::XlaOp outputs = + xla::Conditional(ctx->Input(0), input_tuple, *then_result.computation, + input_tuple, *else_result.computation); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { - if (ctx->input_type(i) != DT_RESOURCE) { - xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); - if (VLOG_IS_ON(2)) { - LOG(INFO) << "Setting output " << i; - auto shape_or = b->GetShape(output_handle); - if (shape_or.ok()) { - LOG(INFO) << "Shape for output " << i << ": " - << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); - } else { - LOG(INFO) << "Shape unknown for output " << i; - } + xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Setting output " << i; + auto shape_or = b->GetShape(output_handle); + if (shape_or.ok()) { + LOG(INFO) << "Shape for output " << i << ": " + << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); + } else { + LOG(INFO) << "Shape unknown for output " << i; } - ctx->SetOutput(i, output_handle); } + ctx->SetOutput(i, output_handle); + } + if (has_token_input_output_) { + // Set token output for this "if" op. + xla::XlaOp token_output = + xla::GetTupleElement(outputs, output_types_.size()); + auto shape_or = b->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); } // Updates the values of any resource variables modified by the conditional @@ -247,6 +274,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp); +REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp); REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index f9bc98a198a72dcc0594e61971713bf890ce30b6..7783e13a8a5dacc1901392703687230020f82483 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel { DataType cond_type_; DataTypeVector input_types_; DataTypeVector output_types_; + bool has_token_input_output_; + std::vector token_input_nodes_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 33a73fe5fdf403e513be085dd7bcea3255277b4a..921b4340c0ac674a5ad7d17aaf54f1cf36975151 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -355,6 +355,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES( context, output_size >= 0, errors::InvalidArgument("Need output_size >= 0, got ", output_size)); + OP_REQUIRES(context, output_size <= kint32max, + errors::InvalidArgument("Need output_size <= kint32Max, got ", + output_size)); xla::XlaOp score_thresh = context->Input("score_threshold"); xla::XlaOp iou_thresh = context->Input("iou_threshold"); @@ -439,12 +442,14 @@ class NonMaxSuppressionOp : public XlaOpKernel { xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); - // num_valid is scalar. - xla::XlaOp num_valid = xla::Reduce( + // num_valid is scalar. Value should be bound by output_size. + xla::XlaOp num_valid_total = xla::Reduce( ones_included, /*init_value=*/xla::ConstantR0(builder, 0), /*computation=*/CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); + xla::XlaOp num_valid = + xla::Min(num_valid_total, xla::ConstantR0(builder, output_size)); xla::XlaOp output_tuple = TopK(scores_included, output_size); xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 8d75624e74028ea083c3facc4f9578ec14c50e6d..7b2bb4a7c50fc954237e09a32f71009f790b60d0 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -32,13 +33,13 @@ namespace { // // 1. S := (N - 1) / gcd(N-1, R-1) // 2. k := (R - 1) / gcd(N-1, R-1) -// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1) +// 3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1) // // For example, to Scale from 7x7 -> 15x15: // // 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 // 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 -// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2) +// 3. Convolution(15x15, stride=3, lhs_dilation=7, padding=2) // // // The 7x7 -> 15x15 case is much too large to write out in full as an @@ -65,6 +66,8 @@ namespace { // 1/9 * 3 6 9 6 3 // 2 4 6 4 2 // 1 2 3 2 1 +// Note that the convolution kernel matrix is separable and thus we can instead +// use 2 consecutive 1D kernel of the dimension 2k-1, along each axis. // Computes the size of the convolutional kernel and stride to use when resizing // from in_size to out_size. @@ -76,7 +79,8 @@ struct ResizeConvolutionDims { std::vector stride; }; ResizeConvolutionDims ComputeResizeConvolutionParameters( - gtl::ArraySlice in_size, gtl::ArraySlice out_size) { + absl::Span in_size, absl::Span out_size, + bool align_corners) { CHECK_EQ(in_size.size(), out_size.size()); int num_spatial_dims = in_size.size(); ResizeConvolutionDims dims; @@ -92,15 +96,32 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( // entry before resizing. dims.stride[i] = dims.kernel_size[i] = 1; } else { - int64 gcd = MathUtil::GCD(static_cast(in_size[i] - 1), - static_cast(out_size[i] - 1)); - dims.stride[i] = (in_size[i] - 1) / gcd; - dims.kernel_size[i] = (out_size[i] - 1) / gcd; + // The scaling factor changes depending on the alignment of corners. + const int64 in_size_factor = align_corners ? in_size[i] - 1 : in_size[i]; + const int64 out_size_factor = + align_corners ? out_size[i] - 1 : out_size[i]; + + int64 gcd = MathUtil::GCD(static_cast(in_size_factor), + static_cast(out_size_factor)); + dims.stride[i] = in_size_factor / gcd; + dims.kernel_size[i] = out_size_factor / gcd; } } return dims; } +// The upper padding of the input needed by ConvGeneralDilated calls is +// determined by solving two related relationships (assuming rhs_dilation == 0): +// 1. dilated_input_dim = lower_padding + upper_padding +// + lhs_dilation * (in_size - 1) + 1 +// 2. dilated_input_dim = (2 * dims.kernel-size - 1) +// + dims.stride * (out_size - 1) +int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, + int64 stride) { + return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) - + 1 - (kernel_size * (in_size - 1)); +} + // Form a 2D convolution kernel like: // 1 2 3 2 1 // 2 4 6 4 2 @@ -112,14 +133,14 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. -std::vector Make1DKernel(int64 n) { +xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) { std::vector kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; kernel[i] = v; kernel[n * 2 - 2 - i] = v; } - return kernel; + return xla::ConstantR1(builder, kernel); } // Kernels with more than 16 spatial elements are considered intense and the @@ -127,43 +148,28 @@ std::vector Make1DKernel(int64 n) { const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, - gtl::ArraySlice kernel_size, + absl::Span kernel_size, int64 channels) { - xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); + auto depthwise_kernel = xla::Broadcast( + xla::Zero(builder, xla::F32), + {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); - auto diag = xla::ConvertElementType( - xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, - 2 * kernel_size[1] - 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), - xla::PrimitiveType::F32); return xla::Mul( - xla::Mul(diag, - xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), + xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]), /*broadcast_dimensions=*/{1}), - xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), + Make1DKernel(builder, kernel_size[0]), /*broadcast_dimensions=*/{0}); } xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, - gtl::ArraySlice kernel_size, + absl::Span kernel_size, int64 channels, int64 dim) { - xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); - - auto diag = xla::ConvertElementType( - xla::Eq( - xla::Broadcast(channels_iota, - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), - xla::PrimitiveType::F32); - if (dim == 1) { - return xla::Mul( - diag, xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), - /*broadcast_dimensions=*/{1}); - } - return xla::Mul(diag, - xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), - /*broadcast_dimensions=*/{0}); + auto depthwise_kernel = + xla::Broadcast(xla::Zero(builder, xla::F32), + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1}); + return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]), + /*broadcast_dimensions=*/{dim}); } xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, @@ -171,7 +177,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, const int num_spatial_dims, std::vector in_size, std::vector out_size, - const int64 channels) { + const int64 channels, + const bool align_corners) { // Picture for a 1x3 to 1x4 resize: // stride = 2, kernel size = 3 // Input: @@ -185,8 +192,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, xla::ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(0); dimension_numbers.set_output_batch_dimension(0); - dimension_numbers.set_input_feature_dimension(3); - dimension_numbers.set_output_feature_dimension(3); + dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1); for (int i = 0; i < num_spatial_dims; ++i) { dimension_numbers.add_input_spatial_dimensions(1 + i); dimension_numbers.add_output_spatial_dimensions(1 + i); @@ -196,37 +203,95 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = - ComputeResizeConvolutionParameters(in_size, out_size); + ComputeResizeConvolutionParameters(in_size, out_size, align_corners); xla::XlaOp output; - // Split convolutions into independent dimensions if they wmuld be a very + + // Concatenation and padding below currently assumes num_spatial_dims is 2 to + // prevent needless code complexity. + CHECK_EQ(num_spatial_dims, 2) + << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently."; + std::vector upper_padding(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + upper_padding[i] = dims.kernel_size[i] - 1; + } + xla::XlaOp input_data = input; + + if (!align_corners) { + // When Tensorflow does not align_corners, the resize indexing can access + // beyond the upper bound and is instead clamped to prevent out of bounds + // reads. This is conceptually the same as extending the edges of the input. + // We emulate this by copying the last row/column of the input. + // Calculate what padding would be needed then determine how far to extend + // the border before lhs dilation. + std::vector num_extended(num_spatial_dims); + upper_padding[0] = CalculateUpperPadding( + in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]); + upper_padding[1] = CalculateUpperPadding( + in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]); + num_extended[0] = upper_padding[0] / (dims.kernel_size[0]); + num_extended[1] = upper_padding[1] / (dims.kernel_size[1]); + + if (num_extended[0] > 0) { + auto slice = + xla::Slice(input_data, {0, in_size[0] - 1, 0, 0}, + {1, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); + for (int i = 0; i < num_extended[0]; i++) { + input_data = xla::ConcatInDim(builder, {input_data, slice}, 1); + } + } + + if (num_extended[1] > 0) { + auto slice = + xla::Slice(input_data, {0, 0, in_size[1] - 1, 0}, + {1, in_size[0] + num_extended[0], in_size[1], channels}, + {1, 1, 1, 1}); + for (int i = 0; i < num_extended[1]; i++) { + input_data = xla::ConcatInDim(builder, {input_data, slice}, 2); + } + } + + // Setting in_size to (in_size + num_extended) due to the above Slice and + // ConcatInDim. Recalculate needed padding after the above Slice/Concat. + upper_padding[0] = + CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0], + dims.kernel_size[0], dims.stride[0]); + upper_padding[1] = + CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1], + dims.kernel_size[1], dims.stride[1]); + } + + // Split convolutions into independent dimensions if they would be a very // large kernel. if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - output = xla::ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = + xla::ConvGeneralDilated(input_data, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, upper_padding[0]}, + {dims.kernel_size[1] - 1, upper_padding[1]}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); output = xla::ConvGeneralDilated( - input, kernel0, {dims.stride[0], 1}, + input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}}, /*lhs_dilation=*/{dims.kernel_size[0], 1}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); xla::XlaOp kernel1 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ - {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/{1, dims.kernel_size[1]}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } // Add broadcasts to handle expanding from a size == 1 dimension to a @@ -245,24 +310,25 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, const int num_spatial_dims, std::vector in_size, std::vector grad_size, - const int64 channels) { + const int64 channels, + const bool align_corners) { ResizeConvolutionDims dims = - ComputeResizeConvolutionParameters(in_size, grad_size); + ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); // To form the backward convolution, we keep the kernel unchanged (it is // already symmetric) and swap the roles of strides and LHS dilation. xla::ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(0); dimension_numbers.set_output_batch_dimension(0); - dimension_numbers.set_input_feature_dimension(3); - dimension_numbers.set_output_feature_dimension(3); + dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1); for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_input_spatial_dimensions(1 + i); - dimension_numbers.add_output_spatial_dimensions(1 + i); + dimension_numbers.add_input_spatial_dimensions(i + 1); + dimension_numbers.add_output_spatial_dimensions(i + 1); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = @@ -285,7 +351,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); @@ -311,14 +378,16 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, /*lhs_dilation=*/{dims.stride[0], 1}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); output = xla::ConvGeneralDilated( output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, /*lhs_dilation=*/{1, dims.stride[1]}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. @@ -341,10 +410,6 @@ class ResizeBilinearOp : public XlaOpKernel { public: explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); - OP_REQUIRES( - ctx, align_corners_ == true, - errors::Unimplemented( - "ResizeBilinear with align_corners=False is not yet implemented")); } void Compile(XlaOpKernelContext* ctx) override { @@ -377,20 +442,19 @@ class ResizeBilinearOp : public XlaOpKernel { // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. - std::vector slice_size = in_size; bool slice_input = false; for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] > 1 && out_size[i] == 1) { // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first // entry before resizing. slice_input = true; - slice_size[i] = 1; + in_size[i] = 1; } } if (slice_input) { - input = xla::Slice(input, {0, 0, 0, 0}, - {batch, slice_size[0], slice_size[1], channels}, - {1, 1, 1, 1}); + input = + xla::Slice(input, {0, 0, 0, 0}, + {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); } // Output is always type float. @@ -406,6 +470,9 @@ class ResizeBilinearOp : public XlaOpKernel { // operations along different dimensions. // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. + // This does not work in the case of align_corners_=false because of special + // padding requirements that cause multiple resizes to be very different + // from a single resize. // // This makes the convolutions kernels smaller and the operation faster. xla::XlaOp output = input; @@ -415,21 +482,24 @@ class ResizeBilinearOp : public XlaOpKernel { (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && - k[0] > 1 && k[1] > 1) { + k[0] > 1 && k[1] > 1 && align_corners_) { std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; - output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, next_out_size, channels); + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, next_out_size, + channels, align_corners_); input = output; in_size = next_out_size; } else { - output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels); + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, out_size, + channels, align_corners_); in_size = out_size; } } else { output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, channels); + in_size, out_size, channels, + align_corners_); in_size = out_size; } } @@ -509,17 +579,20 @@ class ResizeBilinearGradOp : public XlaOpKernel { std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, next_grad_size, channels); + b, grad, num_spatial_dims, in_size, next_grad_size, channels, + align_corners_); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels); + b, grad, num_spatial_dims, in_size, grad_size, channels, + align_corners_); in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels); + b, grad, num_spatial_dims, in_size, grad_size, channels, + align_corners_); in_size = grad_size; } } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 22a45b2a11e8ecb688f8e773ef4b286eafe68f4f..3d81ae9eb89a80e5b89b180ad77521c5ed15e79d 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0(dim))); + xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index eedfc3c9140d7b1ccc1944611de98c1d49fbdaf2..2a42eeaf76ab3aa88ff3a93ef7eb7ab217964bb6 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -29,7 +29,14 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr DoMirrorPad(const xla::XlaOp& t, const xla::Shape& original_shape, const xla::LiteralSlice& pad_literal, + const MirrorPadMode mode, xla::XlaBuilder* b) { + // The difference in the semantics of REFLECT and SYMMETRIC is that REFLECT + // will not mirror the border values while symmetric does. + // e.g. input is [1, 2, 3] and paddings is [0, 2], then the output is: + // - [1, 2, 3, 2, 1] in reflect mode + // - [1, 2, 3, 3, 2] in symmetric mode. + int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0; xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { @@ -39,9 +46,19 @@ class MirrorPadOp : public XlaOpKernel { TF_ASSIGN_OR_RETURN(int64 rhs_padding, pad_literal.GetIntegralAsS64({dimno, 1})); int64 dim_size = original_shape.dimensions(dimno); - auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding, - dim_size - 1, 1, dimno); - auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); + + // Padding amounts on each side must be no more than the size of the + // original shape. + TF_RET_CHECK(lhs_padding >= 0 && + lhs_padding <= dim_size - excluded_edges); + TF_RET_CHECK(rhs_padding >= 0 && + rhs_padding <= dim_size - excluded_edges); + + auto lhs_pad = + xla::SliceInDim(t_rev, dim_size - excluded_edges - lhs_padding, + dim_size - excluded_edges, 1, dimno); + auto rhs_pad = xla::SliceInDim(t_rev, excluded_edges, + excluded_edges + rhs_padding, 1, dimno); accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno); } return accum; @@ -53,9 +70,10 @@ class MirrorPadOp : public XlaOpKernel { MirrorPadMode mode; OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode)); - OP_REQUIRES(ctx, mode == MirrorPadMode::REFLECT, - xla::Unimplemented( - "Only REFLECT MirrorPad mode is currently supported")); + OP_REQUIRES( + ctx, mode == MirrorPadMode::REFLECT || mode == MirrorPadMode::SYMMETRIC, + xla::Unimplemented("Unsupported MirrorPad mode. Only SYMMETRIC and " + "REFLECT modes are currently supported")); const int dims = input_shape.dims(); OP_REQUIRES( @@ -83,7 +101,7 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); xla::StatusOr accum_status = - DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b); + DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, mode, b); OP_REQUIRES_OK(ctx, accum_status.status()); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 3d506e71e03d6b804d1ea0e63c760cfb82629f12..27690c156e4da129ad139f3880bba3a208b5606d 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/pooling.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -71,59 +72,53 @@ class PoolingOp : public XlaOpKernel { int num_dims() const { return num_spatial_dims_ + 2; } - // Method that builds an initial value to use in reductions. - virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; - - // The reduction operation to apply to each window. - virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; - - // A post-processing operation to apply on the outputs of the ReduceWindow. - virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) = 0; - - void Compile(XlaOpKernelContext* ctx) override { - std::vector ksize = ksize_; - std::vector stride = stride_; - if (ctx->num_inputs() != 1) { - const TensorShape ksize_shape = ctx->InputShape(1); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), - errors::InvalidArgument("ksize must be a vector, not shape ", - ksize_shape.DebugString())); - OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window ksize field must " - "specify ", - num_dims(), " dimensions")); - ksize.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); - - const TensorShape stride_shape = ctx->InputShape(2); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), - errors::InvalidArgument("stride must be a vector, not shape ", - stride_shape.DebugString())); - OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window stride field must " - "specify ", - num_dims(), " dimensions")); - stride.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); + protected: + xla::StatusOr> GetKernelSize(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return ksize_; } - const TensorShape input_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("Input to ", type_string(), - " operator must have ", num_dims(), - " dimensions")); + const TensorShape ksize_shape = ctx->InputShape(1); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(ksize_shape)) { + return errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString()); + } + if (ksize_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window ksize field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector ksize; + auto status = ctx->ConstantInputAsIntVector(1, &ksize); + if (!status.ok()) { + return status; + } + return ksize; + } - xla::XlaBuilder* const b = ctx->builder(); - auto input = - XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); - auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize, - stride, padding_); - auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - ctx->SetOutput(0, - PostProcessOutput(ctx, pooled, input_type(0), input_shape)); + xla::StatusOr> GetStride(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return stride_; + } + const TensorShape stride_shape = ctx->InputShape(2); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(stride_shape)) { + return errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString()); + } + if (stride_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window stride field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector stride; + auto status = ctx->ConstantInputAsIntVector(2, &stride); + if (!status.ok()) { + return status; + } + return stride; } protected: @@ -136,24 +131,48 @@ class PoolingOp : public XlaOpKernel { xla::PrimitiveType xla_reduction_type_; }; +// Converts the tensor data format to the one required by the XLA pooling +// library. +xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, + int num_spatial_dims) { + int num_dims = num_spatial_dims + 2; + int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); + int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); + absl::InlinedVector spatial_dimensions(num_spatial_dims); + for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { + spatial_dimensions[spatial_dim] = + GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); + } + return xla::TensorFormat(/*batch_dimension=*/batch_dimension, + /*feature_dimension=*/feature_dimension, + /*spatial_dimensions=*/spatial_dimensions); +} + class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ctx->input_type(0)) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return xla::MinValue(b, xla_reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateMax(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return output; + auto pooling = + xla::MaxPool(ctx->Input(0), ksize, stride, padding_, + XlaTensorFormat(data_format_, input_shape.dims() - 2)); + ctx->SetOutput(0, pooling); } }; @@ -180,60 +199,6 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Common computation shared between AvgPool and AvgPoolGrad. Divide each -// element of an image by the count of elements that contributed to that -// element during pooling. -static xla::XlaOp AvgPoolDivideByCount( - XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape, xla::Padding padding, - const std::vector& ksize, const std::vector& stride, - int num_spatial_dims, TensorFormat data_format) { - if (padding == xla::Padding::kValid) { - // In VALID padding, all windows have the same number of elements - // contributing to each average. Divide by the window size everywhere to - // get the average. - int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1, - [](int64 a, int64 b) { return a * b; }); - - auto divisor = - XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); - return xla::Div(output, divisor); - } else { - // For SAME padding, the padding shouldn't be included in the - // counts. We use another ReduceWindow to find the right counts. - - // TODO(phawkins): use a less brute-force way to compute this. Only - // the boundary regions will have interesting values here. - - std::vector input_dim_sizes(num_spatial_dims); - std::vector window_dims(num_spatial_dims); - std::vector window_ksize(num_spatial_dims); - std::vector window_stride(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i); - input_dim_sizes[i] = input_shape.dim_size(dim); - window_dims[i] = dim; - window_ksize[i] = ksize[dim]; - window_stride[i] = stride[dim]; - } - - // Build a matrix of all 1s, with the same width/height as the input. - const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto ones = xla::Broadcast( - XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); - - // Perform a ReduceWindow with the same window size, strides, and padding - // to count the number of contributions to each result element. - auto reduce = xla::ReduceWindow( - ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, - xla::Padding::kSame); - auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); - - return xla::Div(output, counts, window_dims); - } -} - class AvgPoolOp : public PoolingOp { public: AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) @@ -241,20 +206,34 @@ class AvgPoolOp : public PoolingOp { /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return xla::Zero(b, xla_reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateAdd(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, - ksize_, stride_, num_spatial_dims_, - data_format_); + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); + + auto xla_data_format = + XlaTensorFormat(data_format_, input_shape.dims() - 2); + auto spatial_padding = MakeSpatialPadding( + input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format); + + // Convert the input to the reduction type. + auto converted_input = + ConvertElementType(ctx->Input(0), xla_reduction_type_); + auto pooling = + xla::AvgPool(converted_input, ksize, stride, spatial_padding, + xla_data_format, padding_ == xla::Padding::kValid); + // Convert the pooling result back to the input type before returning it. + ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0))); } }; @@ -431,78 +410,31 @@ class AvgPoolGradOp : public XlaOpKernel { errors::InvalidArgument("out_backprop must be ", num_dims(), "-dimensional")); - int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - int64 depth = out_backprop_shape.dim_size(depth_dim); - - // We can think of average-pooling as: - // * a convolution with a kernel consisting entirely of 1s, where the - // input feature and output feature are equal, and 0s everywhere else. - // * followed by dividing by the counts. - // - // This then gives us an algorithm to build the gradient: - // * divide out_backprop by the counts, followed by - // * Conv2DBackpropInput specialized for that kernel, which simplifies to - // a Pad and a ReduceWindow. - // - // For an explanation of backpropagation for convolution, see the comments - // in third_party/tensorflow/core/kernels/conv_grad_ops.h - - // TF filter shape is [ H, W, ..., inC, outC ] - std::vector filter_dims(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - filter_dims[i] = ksize_[dim]; - } - filter_dims[num_dims() - 2] = depth; - filter_dims[num_dims() - 1] = depth; - TensorShape filter_shape(filter_dims); - - // Reuse the logic from Conv2DBackpropInput to compute padding. - ConvBackpropDimensions dims; - OP_REQUIRES_OK( - ctx, ConvBackpropComputeDimensions( - type_string(), /*num_spatial_dims=*/num_spatial_dims_, - gradients_shape, filter_shape, out_backprop_shape, stride_, - padding_, data_format_, &dims)); - - // The input gradients are computed by a convolution of the output gradients - // and the filter, with some appropriate padding. See the comment at the top - // of conv_grad_ops.h for details. - xla::XlaBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); - auto dtype = input_type(1); + std::vector stride_int64s(stride_.begin(), stride_.end()); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; - - // Divide the out_backprop values by the counts for each spatial position. - std::vector stride_int64s(stride_.begin(), stride_.end()); - auto out_backprop_div = AvgPoolDivideByCount( - ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_, - stride_int64s, num_spatial_dims_, data_format_); - - // Pad the gradients in the spatial dimensions. We use the same padding - // as Conv2DBackpropInput. - xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - auto* padding = padding_config.mutable_dimensions(dim); - padding->set_edge_padding_low(dims.spatial_dims[i].pad_before); - padding->set_edge_padding_high(dims.spatial_dims[i].pad_after); - padding->set_interior_padding(dims.spatial_dims[i].stride - 1); - } - - auto zero = XlaHelpers::Zero(b, dtype); - auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config); - - // in_backprop = padded_gradients ones - std::vector ones(num_dims(), 1LL); - auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto in_backprop = xla::ReduceWindow( - XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), ksize_, - /* window_strides=*/ones, xla::Padding::kValid); - ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype)); + xla::PrimitiveType xla_reduction_type; + auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1)); + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type)); + auto converted_out_backprop = + xla::ConvertElementType(out_backprop, xla_reduction_type); + auto xla_data_format = + XlaTensorFormat(data_format_, gradients_shape.dims() - 2); + auto padding_values = + MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s, + xla_padding, xla_data_format); + auto in_backprop = + xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(), + ksize_, stride_int64s, padding_values, xla_data_format, + /*counts_include_padding=*/padding_ == VALID); + // Convert the pooling result back to the input type before returning it. + xla::PrimitiveType xla_out_backprop_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), + &xla_out_backprop_type)); + ctx->SetOutput(0, + xla::ConvertElementType(in_backprop, xla_out_backprop_type)); } protected: diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index de9068a640dc03b141b6954eaa1629dd6c8c1f3a..7ea0afc1f53cbe4cfcc3f6121a4ecd55864c1b52 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -23,15 +23,10 @@ namespace { class QROp : public XlaOpKernel { public: explicit QROp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - bool full_matrices; - OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices)); - OP_REQUIRES( - ctx, full_matrices, - errors::Unimplemented("full_matrices=False case of QR decomposition is " - "not implemented in TF/XLA")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); } void Compile(XlaOpKernelContext* ctx) override { - auto result = QRDecomposition(ctx->Input(0)); + auto result = QRDecomposition(ctx->Input(0), full_matrices_); if (!result.ok()) { ctx->SetStatus(result.status()); return; @@ -39,6 +34,11 @@ class QROp : public XlaOpKernel { ctx->SetOutput(0, result.ValueOrDie().q); ctx->SetOutput(1, result.ValueOrDie().r); } + + private: + // If true, compute full-sized q and r. If false, compute only the leading P + // columns of q. + bool full_matrices_; }; REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 2da9340625db08b14b78340c471f096baf15689d..afd5986846705f66eb4c7ced9dbe2f4757f5af7f 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -155,7 +155,8 @@ class RandomShuffleOp : public XlaOpKernel { xla::XlaOp indices = xla::Iota(builder, xla::S32, n); // Swap the indices at i and swaps[i]. - auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto swap_body_fn = [&](xla::XlaOp i, + absl::Span loop_vars, xla::XlaBuilder* builder) -> xla::StatusOr> { auto swaps = loop_vars[0]; diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index b11a4ce36da9907ce8fe377c075023a4540797fa..8102faad28db71075fb8da269c55edbdb667193e 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -32,41 +32,30 @@ class ReduceWindowOp : public XlaOpKernel { explicit ReduceWindowOp(OpKernelConstruction* context) : XlaOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_dimensions", &window_dimensions_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_strides", &window_strides_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_)); } void Compile(XlaOpKernelContext* context) override { const TensorShape input_shape = context->InputShape(0); const DataType dtype = context->input_type(0); + std::vector window_dimensions; + std::vector window_strides; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dimensions", &window_dimensions)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + const int rank = input_shape.dims(); - OP_REQUIRES(context, rank == window_dimensions_.size(), + OP_REQUIRES(context, rank == window_dimensions.size(), errors::InvalidArgument( "The size of window_dimensions must be equal to the input " "rank (", - window_dimensions_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == window_strides_.size(), + window_dimensions.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides.size(), errors::InvalidArgument( "The size of window_strides must be equal to the input " "rank (", - window_strides_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_low_.size(), - errors::InvalidArgument( - "The size of padding_low must be equal to the input " - "rank (", - padding_low_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_high_.size(), - errors::InvalidArgument( - "The size of padding_high must be equal to the input " - "rank (", - padding_high_.size(), " vs. ", rank, ")")); - - xla::XlaBuilder* builder = context->builder(); + window_strides.size(), " vs. ", rank, ")")); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -78,6 +67,7 @@ class ReduceWindowOp : public XlaOpKernel { compile_options.use_tuple_arg = false; compile_options.resolve_compile_time_constants = false; compile_options.is_entry_computation = false; + compile_options.always_return_tuple = false; XlaCompiler::CompilationResult reducer; OP_REQUIRES_OK(context, context->compiler()->CompileFunction( compile_options, *computation_, @@ -86,51 +76,47 @@ class ReduceWindowOp : public XlaOpKernel { xla::Shape scalar_shape; OP_REQUIRES_OK(context, TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of ReduceWindow reducer. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + const TensorShape padding_shape = context->InputShape("padding"); OP_REQUIRES(context, - xla::ShapeUtil::Compatible( - reducer.xla_output_shape, - xla::ShapeUtil::MakeTupleShape({scalar_shape})), + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, errors::InvalidArgument( - "Invalid output shape of ReduceWindow reducer. Expected ", - xla::ShapeUtil::HumanString(scalar_shape), " got ", - xla::ShapeUtil::HumanString(reducer.xla_output_shape))); - - // Wraps the reducer in a computation that unpacks the output tuple. - xla::XlaComputation wrapper; - { - std::unique_ptr cb = - builder->CreateSubBuilder("wrapper"); - auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x"); - auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y"); - auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y}); - xla::GetTupleElement(outputs, 0); - xla::StatusOr result = cb->Build(); - OP_REQUIRES_OK(context, result.status()); - wrapper = std::move(result.ValueOrDie()); - } - - std::vector> padding(rank); - for (int i = 0; i < rank; ++i) { - padding[i] = {padding_low_[i], padding_high_[i]}; + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get({i, 0}), + padding_literal.Get({i, 1})}; } xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( - context->Input(0), context->Input(1), wrapper, window_dimensions_, - window_strides_, padding); + context->Input(0), context->Input(1), *reducer.computation, + window_dimensions, window_strides, padding); context->SetOutput(0, output); } private: const NameAttrList* computation_; - std::vector window_dimensions_; - std::vector window_strides_; - std::vector padding_low_; - std::vector padding_high_; TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp); }; -REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp); +REGISTER_XLA_OP(Name("XlaReduceWindow") + .CompileTimeConstInput("window_dimensions") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("padding"), + ReduceWindowOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index b52f0a0ab6290f2019bb58120be5c2364ec15bb6..118f2798d559f43acb7f6394a7337426164325ef 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific reduction Ops. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -29,9 +30,6 @@ namespace tensorflow { XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type) : XlaOpKernel(ctx), reduction_type_(reduction_type) { - const DataType dt = BaseType(input_type(0)); - OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); - OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); OP_REQUIRES_OK( ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); @@ -58,20 +56,24 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { return; } + OP_REQUIRES(ctx, axes_tensor_shape.dims() <= 1, + errors::InvalidArgument( + "Expected scalar or vector as index argument, got ", + axes_tensor_shape.DebugString())); + // Evaluate the constant, reshaping to a 1-vector if it is a scalar. + std::vector axes; xla::Literal axes_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()}, - &axes_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes)); VLOG(1) << "data shape: " << data_shape.DebugString(); - VLOG(1) << "axes : " << axes_literal.ToString(); + VLOG(1) << "axes : " << absl::StrJoin(axes, ","); - gtl::InlinedVector bitmap(data_shape.dims(), false); + absl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { - int32 index = axes_literal.Get({i}); + int64 index = axes[i]; OP_REQUIRES(ctx, !(index < -data_shape.dims() || index >= data_shape.dims()), errors::InvalidArgument("Invalid reduction dimension (", index, @@ -101,7 +103,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::XlaBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. - xla::XlaBuilder r(strings::StrCat(desc, "-reduction")); + xla::XlaBuilder r(absl::StrCat(desc, "-reduction")); xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 121750a82a8c5cbe940068555ad273b7e0d22dfc..366ce42866e9f1375ee0ff6f4985c8f461fc0885 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -41,8 +41,8 @@ class ReshapeOp : public XlaOpKernel { sizes_shape.DebugString())); const int64 num_dims = sizes_shape.num_elements(); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); + std::vector shape_input; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input)); // Compute the output shape. Determine product of specified // dimensions, and find the index of the unspecified one if there @@ -51,7 +51,7 @@ class ReshapeOp : public XlaOpKernel { int64 product = 1; int unknown_index = -1; for (int d = 0; d < num_dims; ++d) { - const int32 size = literal.Get({d}); + const int32 size = shape_input[d]; if (size == -1) { OP_REQUIRES( ctx, unknown_index == -1, diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 1911e6ea362f999c787cbf95dcc9137a6a630273..e172c649325adb6f7761ce0be141f21e8d545bc1 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -48,6 +48,15 @@ class RetvalOp : public XlaOpKernel { } else { xla::XlaOp input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); + DataType input_type = ctx->input_type(0); + XlaContext& tc = XlaContext::Get(ctx); + + if (input_type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + ctx->SetStatus(tc.AddResourceRetval(index_, resource)); + return; + } auto is_constant = ctx->builder()->IsConstant(input); if (!is_constant.ok()) { @@ -55,7 +64,6 @@ class RetvalOp : public XlaOpKernel { return; } - XlaContext& tc = XlaContext::Get(ctx); if (tc.resolve_compile_time_constants() && (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; @@ -104,7 +112,8 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval"), RetvalOp); +REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(), + RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index d962ef4a5f53470838643541f8a1e693d2f4011c..8494864b33a44b03a07e3fea7766285f54074e7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -95,10 +95,24 @@ class ReverseV2Op : public XlaOpKernel { std::vector axes; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes)); + // witnessed_axes is used to ensure that the same axis is not marked to be + // reversed multiple times. + absl::InlinedVector witnessed_axes(x_shape.dims(), false); + for (int d = 0; d < axes.size(); ++d) { - OP_REQUIRES(ctx, (0 <= axes[d]) && (axes[d] < x_shape.dims()), - errors::InvalidArgument(axes[d], " is out of range [0, ", - x_shape.dims(), ").")); + OP_REQUIRES( + ctx, (-x_shape.dims() <= axes[d]) && (axes[d] < x_shape.dims()), + errors::InvalidArgument(axes[d], " is out of range [-", + x_shape.dims(), ", ", x_shape.dims(), ").")); + // Axes can be negative and are shifted to the canonical index before + // being lowered to HLO. + if (axes[d] < 0) { + axes[d] += x_shape.dims(); + } + OP_REQUIRES(ctx, !witnessed_axes[axes[d]], + errors::InvalidArgument("canonicalized axis ", axes[d], + " was repeated.")); + witnessed_axes[axes[d]] = true; } ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes)); diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 6ce50efb4aa6e3434a7c6009cf9f52f6cff9cc9f..9e4c57c9bf73369662274f6b783418e18ff860c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -66,8 +66,8 @@ class SelectOp : public XlaOpKernel { // XLA. It seems we have to broadcast on the left and then Reshape // to get the dimensions in the right order. const auto dim_sizes = then_shape.dim_sizes(); - gtl::ArraySlice bdims = dim_sizes; - bdims.pop_front(); + absl::Span bdims = dim_sizes; + bdims.remove_prefix(1); cond_handle = xla::Broadcast(cond_handle, bdims); std::vector dim_order(then_shape.dims()); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 4e0cf99d8e7ff45ed9145981b5e2e637ce4d4e4b..c8a0f31a0375abacaca26688a23f4835e11c692e 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -44,7 +44,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -66,7 +66,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -82,7 +82,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -101,7 +101,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: @@ -115,7 +115,7 @@ class ExpandDimsOp : public XlaOpKernel { // accept legacy scalars, even when they should be forbidden by the graphdef // version. OP_REQUIRES(ctx, dim_shape.num_elements() == 1, - errors::InvalidArgument(strings::StrCat( + errors::InvalidArgument(absl::StrCat( "dim input to ExpandDims must be a scalar; got ", dim_shape.DebugString()))); diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 6adc3c58de63ee70c26bed47eebef955893df4a5..537b71f3c0cf3622a8a45a717ac406da69f5c3c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Slice Op. +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mem.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 025ba827410f1a9f993a8a1855558a2daa86609b..d6bd927135c013ac1ec3f6547aef358dc2741896 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Ops for softmax. +#include "absl/strings/match.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace { @@ -33,7 +33,7 @@ namespace { class SoftmaxOp : public XlaOpKernel { public: explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - log_ = str_util::StartsWith(type_string(), "Log"); + log_ = absl::StartsWith(type_string(), "Log"); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 7327258c31f21f45ff7ffffbc9db7a2a70b4a14c..76b79be6f6f6b5ecbe9edcffb81f2834fdac9a56 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -23,10 +23,10 @@ namespace { void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, DataType input_dtype, const TensorShape& input_tensor_shape, - gtl::ArraySlice block_shape, + absl::Span block_shape, const xla::Literal& paddings) { const int input_rank = input_tensor_shape.dims(); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); const int block_rank = block_shape.size(); @@ -34,7 +34,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, ctx, input_rank >= 1 + block_rank, errors::InvalidArgument("input rank should be >= ", 1 + block_rank, " instead of ", input_rank)); - gtl::ArraySlice remainder_shape(input_shape); + absl::Span remainder_shape(input_shape); remainder_shape.remove_prefix(1 + block_rank); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 4493539fe34f0ce635fdc58660d4ff90af9c9379..3293c13b21bc4825c83f494b7f2d48a9b3000f9e 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -48,7 +48,7 @@ class SpaceToDepthOp : public XlaOpKernel { OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got ", input_rank)); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::XlaOp input = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index df91900570107609c0f1c2281faaab8a5e65b98b..ee70f508a9586d5f47bd7bb7670506d4c92b369f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -111,7 +111,7 @@ class StackOp : public XlaOpKernel { xla::XlaOp value; XlaContext& xc = XlaContext::Get(ctx); XlaResource* resource; - string name = strings::StrCat("Stack: ", stack_name_); + string name = absl::StrCat("Stack: ", stack_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, TensorShape(), value, /*tensor_array_size=*/size, diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 1062399d91bd9a9bf8c3820c5ecac534c110746d..2b2e3de64fd0db9d99efa46ecaf7a0fefbae6645 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/util/strided_slice_op.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mem.h" namespace tensorflow { @@ -46,9 +46,9 @@ class StridedSliceOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); TensorShape final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); @@ -72,8 +72,8 @@ class StridedSliceOp : public XlaOpKernel { shrink_axis_mask_, &dummy_processing_shape, &final_shape, &dummy, &dummy, &dummy, &begin, &end, &strides)); - gtl::InlinedVector dimensions_to_reverse; - gtl::InlinedVector slice_begin, slice_end, slice_strides; + absl::InlinedVector dimensions_to_reverse; + absl::InlinedVector slice_begin, slice_end, slice_strides; for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { @@ -127,9 +127,9 @@ class StridedSliceGradOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape processing_shape, final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -175,7 +175,7 @@ class StridedSliceGradOp : public XlaOpKernel { grad = xla::Reshape(grad, processing_shape.dim_sizes()); // Pad the input gradients. - gtl::InlinedVector dimensions_to_reverse; + absl::InlinedVector dimensions_to_reverse; xla::PaddingConfig padding_config; for (int i = 0; i < processing_shape.dims(); ++i) { @@ -238,9 +238,9 @@ class StridedSliceAssignOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); @@ -287,8 +287,8 @@ class StridedSliceAssignOp : public XlaOpKernel { xla::XlaOp rhs = ctx->Input(4); - gtl::InlinedVector dimensions_to_reverse; - gtl::InlinedVector slice_begin, slice_dims; + absl::InlinedVector dimensions_to_reverse; + absl::InlinedVector slice_begin, slice_dims; for (int i = 0; i < begin.size(); ++i) { // TODO(phawkins): implement strides != 1 OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index be1814d8e3ae2c0ddad0134b9288e0ea084aa81b..94108b764fd32fc77520f9a8ea16065c27e6accf 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -122,7 +122,7 @@ Status GetTensorArrayShape(const XlaResource* resource, // relevant slice of 'operand'. xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, - const gtl::ArraySlice& update_dims, + absl::Span update_dims, const xla::XlaOp& start_indices) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); xla::XlaOp sum = xla::Add(current, update); @@ -167,7 +167,7 @@ class TensorArrayOp : public XlaOpKernel { XlaContext& xc = XlaContext::Get(ctx); XlaResource* var; - string name = strings::StrCat("TensorArray: ", tensor_array_name_); + string name = absl::StrCat("TensorArray: ", tensor_array_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), dtype_, shape, value, /*tensor_array_size=*/size, diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 1233a37565d3a40c6dd2882b3139dedbf690a7b6..93d5996b5eaf10221b1d7067e7650b78cd6b8fef 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Tile Op. #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -70,7 +70,7 @@ class TileOp : public XlaOpKernel { bool one_dimension_is_broadcasted_without_multiple = true; for (int i = 0; i < input_dims; ++i) { int multiple = literal.Get({i}); - OP_REQUIRES(ctx, multiple, + OP_REQUIRES(ctx, multiple >= 0, errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ", multiple)); int64 new_dim = input_shape.dim_size(i) * multiple; diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index be5e91138656716daddcc3c7a68dbb78ecb69103..7077c2e3a546e198bdb4ff944ea531f3158810f2 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -688,7 +688,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, } // grad_to_use = grad + 2 * l2_shrinkage * var - // new_accum = accum + grad_to_use * grad_to_use + // new_accum = accum + grad * grad // linear += grad_to_use - // (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2 @@ -704,7 +704,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, grad_to_use = grad; } - xla::XlaOp new_accum = accum + xla::Square(grad_to_use); + xla::XlaOp new_accum = accum + xla::Square(grad); xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power); xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power); linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index f9148b394212777271f9eba51313ee17b19819af..6b303b31d43ce2249a87f25723caf34f84c8387d 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -61,7 +61,7 @@ class TransposeOp : public XlaOpKernel { std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). - gtl::InlinedVector bits(dims); + absl::InlinedVector bits(dims); bool is_identity = true; for (int i = 0; i < dims; ++i) { const int32 d = perm[i]; diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 1e8a376765d36ffa677ece06fbd131744299e04b..559414eeaa5fec75e5a9d1866baaf738c024cd15 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/while_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { cond_name_attr_ = *name_attr; OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr)); body_name_attr_ = *name_attr; + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { @@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { body_options.return_updated_values_for_all_resources = true; body_options.resolve_compile_time_constants = false; body_options.is_entry_computation = false; + body_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult body; OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, arguments, &body)); @@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { cond_options.use_tuple_arg = true; cond_options.resolve_compile_time_constants = false; cond_options.is_entry_computation = false; + cond_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult cond; OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); @@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = body.input_mapping[i]; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "while" op. + std::vector token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(builder, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder)); @@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(while_result, i)); } } + if (has_token_input_output_) { + // Set token output for this "while" op. + xla::XlaOp token_output = + xla::GetTupleElement(while_result, ctx->num_outputs()); + auto shape_or = builder->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the loop. for (int i = 0; i < body.resource_updates.size(); ++i) { @@ -301,6 +330,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); +REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp); REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 67edebabf9f643a919d0f06c228e2d224a49a2af..aeeff40e68f8b778628b9e85bd9b4ddcb73883a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel { private: NameAttrList cond_name_attr_; NameAttrList body_name_attr_; + bool has_token_input_output_; + std::vector token_input_nodes_; TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..412afeaaad96842521fbd306f5b666e837e675fd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +class XlaBroadcastHelperOp : public XlaOpKernel { + public: + explicit XlaBroadcastHelperOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp lhs = context->Input(0); + xla::XlaOp rhs = context->Input(1); + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + + const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims(); + const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape; + const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape; + + std::vector broadcast_dims; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("broadcast_dims", + &broadcast_dims)); + if (broadcast_dims.empty()) { + OP_REQUIRES( + context, + lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 || + rhs_shape.dims() == 0, + errors::InvalidArgument( + "If broadcast_dims is empty, both " + "arguments must have equal rank; " + "argument shapes, or at least one argument must be a scalar: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + context->SetOutput(0, lhs); + context->SetOutput(1, rhs); + return; + } + + OP_REQUIRES( + context, broadcast_dims.size() == min_rank_shape->dims(), + errors::InvalidArgument( + "broadcast_dims must have size equal to the smaller argument rank; " + "broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + std::vector sorted_broadcast_dims = broadcast_dims; + absl::c_sort(sorted_broadcast_dims); + std::set dims_set(broadcast_dims.begin(), broadcast_dims.end()); + OP_REQUIRES(context, + dims_set.size() == broadcast_dims.size() && + broadcast_dims == sorted_broadcast_dims, + errors::InvalidArgument( + "Duplicate or nonmonotonic dimension in broadcast_dims; " + "broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]")); + + std::vector broadcast_shape(max_rank_shape->dims(), 1LL); + for (int i = 0; i < broadcast_dims.size(); ++i) { + const int dim = broadcast_dims[i]; + OP_REQUIRES( + context, dim >= 0 && dim < broadcast_shape.size(), + errors::InvalidArgument( + "Invalid broadcast dimension (", dim, "); broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + broadcast_shape[dim] = min_rank_shape->dim_size(i); + } + xla::PrimitiveType type = context->input_xla_type(0); + xla::Shape broadcast_xla_shape = + xla::ShapeUtil::MakeShape(type, broadcast_shape); + if (broadcast_lhs) { + lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims); + } else { + rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims); + } + context->SetOutput(0, lhs); + context->SetOutput(1, rhs); + } + + private: + xla::DotDimensionNumbers dnums_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaBroadcastHelperOp); +}; + +REGISTER_XLA_OP( + Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"), + XlaBroadcastHelperOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fecc7c556eb4121b912796e5811632c46769b479 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaConvOp : public XlaOpKernel { + public: + explicit XlaConvOp(OpKernelConstruction* context) : XlaOpKernel(context) { + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + string precision_config_attr; + OP_REQUIRES_OK( + context, context->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES( + context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + const TensorShape padding_shape = context->InputShape("padding"); + std::vector window_strides; + std::vector lhs_dilation; + std::vector rhs_dilation; + int64 feature_group_count; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("lhs_dilation", + &lhs_dilation)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("rhs_dilation", + &rhs_dilation)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar( + "feature_group_count", &feature_group_count)); + + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, + errors::InvalidArgument( + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get({i, 0}), + padding_literal.Get({i, 1})}; + } + + // We do only minimal checking, relying on XLA to check the shape + // invariants. + xla::XlaOp output = xla::ConvGeneralDilated( + context->Input(0), context->Input(1), window_strides, padding, + lhs_dilation, rhs_dilation, dnums_, feature_group_count, + &precision_config_); + context->SetOutput(0, output); + } + + private: + xla::ConvolutionDimensionNumbers dnums_; + xla::PrecisionConfig precision_config_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp); +}; + +REGISTER_XLA_OP(Name("XlaConv") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("lhs_dilation") + .CompileTimeConstInput("rhs_dilation") + .CompileTimeConstInput("feature_group_count") + .CompileTimeConstInput("padding"), + XlaConvOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..40b15b5579ab9862b9d30df74af9877c98c4aa2c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaDotOp : public XlaOpKernel { + public: + explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) { + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + string precision_config_attr; + OP_REQUIRES_OK( + context, context->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES( + context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + + // We do only minimal checking, relying on XLA to check the shape + // invariants. + xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1), + dnums_, &precision_config_); + context->SetOutput(0, output); + } + + private: + xla::DotDimensionNumbers dnums_; + xla::PrecisionConfig precision_config_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp); +}; + +REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..59502d83c7338bd1b05b3323a97761fff2da186a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -0,0 +1,105 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaPadOp : public XlaOpKernel { + public: + explicit XlaPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape("input"); + const TensorShape padding_value_shape = + context->InputShape("padding_value"); + + std::vector padding_low; + std::vector padding_high; + std::vector padding_interior; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_low", + &padding_low)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_high", + &padding_high)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "padding_interior", &padding_interior)); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(padding_value_shape), + errors::InvalidArgument("padding_value must be a scalar")); + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == padding_low.size(), + errors::InvalidArgument( + "The size of padding_low must be equal to the input " + "rank (", + padding_low.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_high.size(), + errors::InvalidArgument( + "The size of padding_high must be equal to the input " + "rank (", + padding_high.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_interior.size(), + errors::InvalidArgument( + "The size of padding_interior must be equal to the input " + "rank (", + padding_interior.size(), " vs. ", rank, ")")); + + auto non_negative = [](int64 x) { return x >= 0; }; + OP_REQUIRES( + context, absl::c_all_of(padding_low, non_negative), + errors::InvalidArgument("padding_low must be non-negative, got [", + absl::StrJoin(padding_low, ","), "]")); + OP_REQUIRES( + context, absl::c_all_of(padding_high, non_negative), + errors::InvalidArgument("padding_high must be non-negative, got [", + absl::StrJoin(padding_high, ","), "]")); + OP_REQUIRES( + context, absl::c_all_of(padding_interior, non_negative), + errors::InvalidArgument("padding_interior must be non-negative, got [", + absl::StrJoin(padding_interior, ","), "]")); + + xla::PaddingConfig padding_config; + for (int i = 0; i < rank; ++i) { + auto* dim = padding_config.add_dimensions(); + dim->set_edge_padding_low(padding_low[i]); + dim->set_edge_padding_high(padding_high[i]); + dim->set_interior_padding(padding_interior[i]); + } + + xla::XlaOp output = + xla::Pad(context->Input("input"), context->Input("padding_value"), + padding_config); + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaPadOp); +}; + +REGISTER_XLA_OP(Name("XlaPad") + .CompileTimeConstInput("padding_low") + .CompileTimeConstInput("padding_high") + .CompileTimeConstInput("padding_interior"), + XlaPadOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc2425f37bfa793ce3a106b635c9dffd15b975ff --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaReduceOp : public XlaOpKernel { + public: + explicit XlaReduceOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("reducer", &reducer_)); + OP_REQUIRES_OK(context, context->GetAttr("dimensions_to_reduce", + &dimensions_to_reduce_)); + std::set dims_set(dimensions_to_reduce_.begin(), + dimensions_to_reduce_.end()); + OP_REQUIRES( + context, dims_set.size() == dimensions_to_reduce_.size(), + errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce " + "argument to XlaReduce")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape("input"); + const TensorShape init_value_shape = context->InputShape("init_value"); + const DataType dtype = context->input_type(0); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape), + errors::InvalidArgument("init_value must be a scalar")); + + auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; }; + OP_REQUIRES(context, + rank >= dimensions_to_reduce_.size() && + absl::c_all_of(dimensions_to_reduce_, dim_in_range), + errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaReduce")); + + // Build the reducer function. + XlaCompiler::Argument reducer_arg; + reducer_arg.kind = XlaCompiler::Argument::kParameter; + reducer_arg.type = dtype; + reducer_arg.shape = TensorShape(); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.always_return_tuple = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + XlaCompiler::CompilationResult reducer; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *reducer_, + {reducer_arg, reducer_arg}, &reducer)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of XlaReduce reducer. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + xla::XlaOp output = + xla::Reduce(context->Input("input"), context->Input("init_value"), + *reducer.computation, dimensions_to_reduce_); + context->SetOutput(0, output); + } + + private: + const NameAttrList* reducer_; + std::vector dimensions_to_reduce_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp); +}; + +REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..089776fcf74fcf6b363dfff5de8d86d7449eacd6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/while_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaSelectAndScatterOp : public XlaOpKernel { + public: + explicit XlaSelectAndScatterOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_)); + OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const DataType dtype = context->input_type(0); + + std::vector window_dimensions; + std::vector window_strides; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dimensions", &window_dimensions)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == window_dimensions.size(), + errors::InvalidArgument( + "The size of window_dimensions must be equal to the input " + "rank (", + window_dimensions.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides.size(), + errors::InvalidArgument( + "The size of window_strides must be equal to the input " + "rank (", + window_strides.size(), " vs. ", rank, ")")); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + compile_options.always_return_tuple = false; + + // Build the select function. + XlaCompiler::Argument select_arg; + select_arg.kind = XlaCompiler::Argument::kParameter; + select_arg.type = dtype; + select_arg.shape = TensorShape(); + + XlaCompiler::CompilationResult select; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *select_computation_, + {select_arg, select_arg}, &select)); + + xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {}); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(select.xla_output_shape, + select_output_shape), + errors::InvalidArgument( + "Invalid output shape of XlaSelectAndScatter select. Expected ", + xla::ShapeUtil::HumanString(select_output_shape), " got ", + xla::ShapeUtil::HumanString(select.xla_output_shape))); + + // Build the scatter function. + XlaCompiler::Argument scatter_arg; + scatter_arg.kind = XlaCompiler::Argument::kParameter; + scatter_arg.type = dtype; + scatter_arg.shape = TensorShape(); + + XlaCompiler::CompilationResult scatter; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *scatter_computation_, + {scatter_arg, scatter_arg}, &scatter)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of scatter. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(scatter.xla_output_shape))); + + const TensorShape padding_shape = context->InputShape("padding"); + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, + errors::InvalidArgument( + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get({i, 0}), + padding_literal.Get({i, 1})}; + } + + xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding( + context->Input("operand"), *select.computation, window_dimensions, + window_strides, padding, context->Input("source"), + context->Input("init_value"), *scatter.computation); + context->SetOutput(0, output); + } + + private: + const NameAttrList* select_computation_; + const NameAttrList* scatter_computation_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp); +}; + +REGISTER_XLA_OP(Name("XlaSelectAndScatter") + .CompileTimeConstInput("window_dimensions") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("padding"), + XlaSelectAndScatterOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index cb7a40e23d539f758d963791f1c2b4d37374ade5..8597e7f139d8d32b7e08782e70a4ee44d02618f2 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -44,8 +44,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:lib", ], @@ -78,8 +78,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", @@ -104,6 +104,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -119,6 +120,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", @@ -165,6 +167,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -202,6 +205,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index f666d22ea44216beef74608bb4d9f33fb2fe82c6..5400e8834cb9807f6dd71abe7789b2672e29e905 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -27,7 +27,8 @@ limitations under the License. namespace tensorflow { xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y) { + bool transpose_y, bool conjugate_x, bool conjugate_y, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); @@ -95,15 +96,9 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, y = xla::Conj(y); } - // If there are no batch dimensions, use a regular Dot. - // TODO(b/69062148) Remove this code when Dot emitters can be passed - // dimensions to transpose directly (i.e. without requiring a Transpose - // HLO). - if (batch_dimension_numbers.empty()) { - auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; - return xla::Dot(lhs, rhs); - } + xla::PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); xla::DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); @@ -112,7 +107,8 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - return xla::DotGeneral(x, y, dot_dnums); + + return xla::DotGeneral(x, y, dot_dnums, &precision_proto); }); } diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 8757b16a1ca6a8cec5e3c801c885e7bbbb2f2c76..6edd63a4d3b66c21aa4cce8c9f36eef0dc363cd8 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -43,9 +43,11 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false); +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool conjugate_y = false, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 87d73eb3f07ebd7dfa4fef50ebe76cad0c4ed117..ab3d0a566839343828d176d9a46672824e425613 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -49,20 +49,22 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { +xla::XlaOp CholeskyUnblocked(xla::XlaOp a, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int n_dims = xla::ShapeUtil::Rank(a_shape); const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - 2); + auto major_dims = xla::AsInt64Slice(a_shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - 2); xla::XlaOp l = xla::ZerosLike(a); // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, xla::XlaBuilder* body_builder) -> xla::StatusOr> { xla::Shape col_shape; @@ -101,7 +103,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { // np.dot(row, np.swapaxes(row, -1, -2)) auto diag_dot = BatchDot(row, row, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) auto l_ii = @@ -121,7 +124,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { // r.T) auto dot = BatchDot(body_l, row, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); // np.dot(l[..., i+1:, :i], r.T) auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); @@ -145,7 +149,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -181,14 +186,15 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); - auto factorized = CholeskyUnblocked(x); + auto factorized = CholeskyUnblocked(x, precision); l = UpdateSliceInMinorDims(l, factorized, {i, i}); if (i + k < n) { diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 1bef9bb166c576ec665bb48265b4da200ddca2a0..9a561c34b92ee45059f2a05336e682838f8e36e2 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -30,7 +30,9 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256); +xla::XlaOp Cholesky( + xla::XlaOp a, int64 block_size = 256, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index fc0c1ee838190b1f1a7ca5b901c97e0a35232a97..6b3f2b6e065b5c99e2d0248237369ecc30188aa5 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -65,9 +65,9 @@ namespace { // return (v, tau, beta) // TODO(phawkins): LAPACK's xLARFG implementation has code for handling // overflows in the norm/beta calculations. Perhaps do the same here. -xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice batch_dims, - const int64 m, xla::XlaOp* v, xla::XlaOp* tau, - xla::XlaOp* beta) { +xla::Status House(xla::XlaOp x, xla::XlaOp k, + absl::Span batch_dims, const int64 m, + xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { xla::XlaBuilder* const builder = x.builder(); TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); const xla::PrimitiveType type = x_shape.element_type(); @@ -149,7 +149,8 @@ struct QRBlockResult { xla::XlaOp taus; // Shape: [..., n] xla::XlaOp vs; // Shape: [..., m, n] }; -xla::StatusOr QRBlock(xla::XlaOp a) { +xla::StatusOr QRBlock( + xla::XlaOp a, xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -172,7 +173,7 @@ xla::StatusOr QRBlock(xla::XlaOp a) { std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); auto qr_body_fn = - [&](xla::XlaOp j, gtl::ArraySlice values, + [&](xla::XlaOp j, absl::Span values, xla::XlaBuilder* builder) -> xla::StatusOr> { auto a = values[0]; auto vs = values[1]; @@ -190,8 +191,12 @@ xla::StatusOr QRBlock(xla::XlaOp a) { auto v_broadcast = xla::Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = BatchDot(v_broadcast, a); - vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true); + auto vva = + BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + vva = + BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); a = a - xla::Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); @@ -250,14 +255,15 @@ xla::StatusOr QRBlock(xla::XlaOp a) { // There is no need to return Y since at termination of the loop it is equal to // vs. xla::StatusOr ComputeWYRepresentation( - xla::PrimitiveType type, gtl::ArraySlice batch_dims, xla::XlaOp vs, - xla::XlaOp taus, int64 m, int64 n) { + xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, + xla::XlaOp taus, int64 m, int64 n, + xla::PrecisionConfig::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; auto body_fn = - [&](xla::XlaOp j, gtl::ArraySlice values, + [&](xla::XlaOp j, absl::Span values, xla::XlaBuilder* builder) -> xla::StatusOr> { auto w = values[0]; auto y = values[1]; @@ -272,9 +278,12 @@ xla::StatusOr ComputeWYRepresentation( auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true); + auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); // wyv has shape [..., m, 1] - auto wyv = BatchDot(w, yv); + auto wyv = + BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); auto z = xla::Mul( -beta, v + wyv, @@ -321,8 +330,9 @@ xla::StatusOr ComputeWYRepresentation( // return (q, a) // TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. -xla::StatusOr QRDecomposition(xla::XlaOp a, - int64 block_size) { +xla::StatusOr QRDecomposition( + xla::XlaOp a, bool full_matrices, int64 block_size, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -352,33 +362,47 @@ xla::StatusOr QRDecomposition(xla::XlaOp a, int64 k = std::min(block_size, p - i); auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k}); - TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block)); + TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision)); a = UpdateSliceInMinorDims(a, qr_block.r, {i, i}); // Compute the I-WY block representation of a product of Householder // matrices. - TF_ASSIGN_OR_RETURN(auto w, - ComputeWYRepresentation(type, batch_dims, qr_block.vs, - qr_block.taus, m - i, k)); + TF_ASSIGN_OR_RETURN( + auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs, + qr_block.taus, m - i, k, precision)); auto y = qr_block.vs; // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true); - a_update = BatchDot(y, a_update); + auto a_update = + BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + a_update = + BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = BatchDot(q_panel, w); - q_update = - BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true); + auto q_update = + BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + q_update = BatchDot(q_update, y, /*transpose_x=*/false, + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } QRDecompositionResult result; + + // full_matrices is false when only a partial result in needed. Slice to the + // needed dimensions here. + if (!full_matrices) { + q = SliceInMinorDims(q, {0, 0}, {m, p}); + a = SliceInMinorDims(a, {0, 0}, {p, n}); + } result.q = q; result.r = a; return result; diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index abd2316ac961f583dd29f90f43cf6209de30bd6a..24b537ac8b63b93e734c3d0e335ea455f7d51a54 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -32,8 +33,9 @@ struct QRDecompositionResult { xla::XlaOp r; }; -xla::StatusOr QRDecomposition(xla::XlaOp a, - int64 block_size = 128); +xla::StatusOr QRDecomposition( + xla::XlaOp a, bool full_matrices, int64 block_size = 128, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index ba22eff73abab11abeb57283c63318b2e50a9ca1..38dfde165df47ca78a25a068a901cd1071aa55e2 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -40,9 +40,9 @@ xla::StatusOr XlaScatter( TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); - gtl::ArraySlice indices_dims = + absl::Span indices_dims = xla::AsInt64Slice(indices_shape.dimensions()); - gtl::ArraySlice buffer_dims = + absl::Span buffer_dims = xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains @@ -58,7 +58,7 @@ xla::StatusOr XlaScatter( ") must be <= the rank of the buffer (shape: ", xla::ShapeUtil::HumanString(buffer_shape), ")"); } - indices_dims.pop_back(); + indices_dims.remove_suffix(1); } int64 num_indices = 1; @@ -107,7 +107,7 @@ xla::StatusOr XlaScatter( // index = dynamic-slice(indices, i) // update = dynamic-slice(updates, i) // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, xla::XlaBuilder* body_builder) { auto indices = loop_vars[0]; auto updates = loop_vars[1]; diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 04fa10108cef66f429392951eea70e59643a2d29..6524c2a9b1ada632d80edd234272760c2b545cc4 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -57,7 +57,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // We can grab entire blocks using gather if (n > block_size) { // Construct the starting indices of the diagonal blocks - auto gather_indices = + auto start_indices = Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), xla::ConstantR0(builder, block_size)), /*broadcast_sizes=*/{2}), @@ -65,13 +65,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Gather the diagonal blocks xla::GatherDimensionNumbers dim_numbers; - dim_numbers.add_output_window_dims(ndims - 1); - dim_numbers.add_output_window_dims(ndims); - dim_numbers.add_gather_dims_to_operand_dims(ndims - 2); - dim_numbers.add_gather_dims_to_operand_dims(ndims - 1); + dim_numbers.add_offset_dims(ndims - 1); + dim_numbers.add_offset_dims(ndims); + dim_numbers.add_start_index_map(ndims - 2); + dim_numbers.add_start_index_map(ndims - 1); dim_numbers.set_index_vector_dim(1); - diag_blocks = Gather(a, gather_indices, dim_numbers, - /*window_bounds=*/{block_size, block_size}); + diag_blocks = Gather(a, start_indices, dim_numbers, + /*slice_sizes=*/{block_size, block_size}); } // The last block might be smaller than the block size, @@ -111,7 +111,8 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { } xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, - bool transpose_a, bool conjugate_a) { + bool transpose_a, bool conjugate_a, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = diag_blocks.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // Input is a batch of square lower triangular square matrices. Its shape is @@ -215,7 +216,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - auto update = -DotGeneral(input_row, body_out, dnums); + xla::PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); body_out = DynamicUpdateSlice(body_out, update, start_indices); @@ -238,10 +242,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, }); } -xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, - xla::XlaOp inv_diag_blocks, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a) { +xla::XlaOp SolveWithInvertedDiagonalBlocks( + xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, @@ -307,9 +311,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false); + remainder = b_row - BatchDot(a_row, x, transpose_a, false, + /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a); + remainder = b_row - BatchDot(x, a_row, false, transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); } } @@ -319,9 +327,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, xla::ConstantR0WithType(builder, xla::S32, j * block_size); std::vector update_starts = {start_index, zero}; if (left_side) { - x_update = BatchDot(inv_block, remainder, transpose_a, false); + x_update = + BatchDot(inv_block, remainder, transpose_a, false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); } else { - x_update = BatchDot(remainder, inv_block, false, transpose_a); + x_update = + BatchDot(remainder, inv_block, false, transpose_a, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); @@ -333,7 +345,8 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - int64 block_size) { + int64 block_size, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -388,12 +401,13 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, auto diag_blocks = DiagonalBlocks(a, block_size); // We invert these blocks in parallel using batched matrix-vector products - auto inv_diag_blocks = - InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a); + auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, + conjugate_a, precision); // We now find the solution using GEMMs - auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, - lower, transpose_a, conjugate_a); + auto x = + SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, + transpose_a, conjugate_a, precision); return x; }); diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 555760b7efabddfb25c9135b109a1c48b487415e..2303234f361e54cd2a0ad495cb03b371bed76877 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -57,9 +57,10 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 block_size = 128); +xla::XlaOp TriangularSolve( + xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size = 128, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 8b5beba383cda45d36e2ee27ca5e3b3c5988b6b7..804671fbc75b0a5a6e04b204822b6f084013cd8b 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::C64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: - literal = std::move( - *xla::LiteralUtil::CreateR0(static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::F16: - literal = std::move(*xla::LiteralUtil::CreateR0( - static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; @@ -113,8 +113,8 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, - gtl::ArraySlice end) { +xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, + absl::Span end) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_RET_CHECK(start.size() == end.size()); @@ -124,9 +124,10 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); + auto major_dims = xla::AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); // Prepends 0s in the major dim std::vector padded_start(n_dims, 0); @@ -143,8 +144,8 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, }); } -std::vector ConcatVectors(gtl::ArraySlice xs, - gtl::ArraySlice ys) { +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { std::vector output(xs.size() + ys.size()); std::copy(xs.begin(), xs.end(), output.begin()); std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); @@ -152,8 +153,8 @@ std::vector ConcatVectors(gtl::ArraySlice xs, } xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - gtl::ArraySlice starts, - gtl::ArraySlice sizes) { + absl::Span starts, + absl::Span sizes) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); @@ -161,9 +162,10 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, int64 n_minor_dims = starts.size(); TF_RET_CHECK(n_minor_dims == sizes.size()); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - sizes.size()); + auto major_dims = xla::AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - sizes.size()); auto padded_starts = PrependZerosInMajorDims(x, starts); auto padded_sizes = ConcatVectors(major_dims, sizes); return xla::DynamicSlice(x, padded_starts, padded_sizes); @@ -171,7 +173,7 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, } xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start) { + absl::Span start) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // TODO(phawkins): make int64 work on all backends, remove the int32 cast. @@ -189,7 +191,7 @@ xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, } xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start) { + absl::Span start) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); @@ -204,13 +206,13 @@ xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, } xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice starts) { + absl::Span starts) { auto padded_starts = PrependZerosInMajorDims(x, starts); return xla::DynamicUpdateSlice(x, update, padded_starts); } xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts) { + absl::Span starts) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index b4905c952820a45371e090aa98466654e2db9661..80e9e5b002d49581209e608b98606e02709c5876 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -31,7 +31,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts); + absl::Span starts); // Returns a integer scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. @@ -41,33 +41,33 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Builds a vector of zeros of length rank(x) with the last values being // those in `starts`. xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts); + absl::Span starts); // Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, - gtl::ArraySlice end); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, + absl::Span end); // Returns the concatenation of `xs` and `ys`. -std::vector ConcatVectors(gtl::ArraySlice xs, - gtl::ArraySlice ys); +std::vector ConcatVectors(absl::Span xs, + absl::Span ys); // Performs a dynamic slice in the minor dimensions of a Tensor. xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - gtl::ArraySlice starts, - gtl::ArraySlice sizes); + absl::Span starts, + absl::Span sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start); + absl::Span start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start); + absl::Span start); xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice starts); + absl::Span starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. xla::XlaOp TransposeInMinorDims(xla::XlaOp x); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index d64394f1401d7ceea004a59c991ef6f4a1c58b41..594ab1dfd0700f47501712183f6efe62d17e15e7 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -24,7 +24,7 @@ namespace tensorflow { xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; @@ -47,7 +47,7 @@ xla::StatusOr> XlaWhileLoop( // Build the condition. std::unique_ptr cond_builder = - builder->CreateSubBuilder(strings::StrCat(name, "_condition")); + builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { auto parameter = xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); @@ -61,7 +61,7 @@ xla::StatusOr> XlaWhileLoop( // Build the body. std::unique_ptr body_builder = - builder->CreateSubBuilder(strings::StrCat(name, "_body")); + builder->CreateSubBuilder(absl::StrCat(name, "_body")); { auto parameter = xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); @@ -84,15 +84,15 @@ xla::StatusOr> XlaWhileLoop( xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder) { auto while_cond_fn = - [&](gtl::ArraySlice values, + [&](absl::Span values, xla::XlaBuilder* cond_builder) -> xla::StatusOr { return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); }; - auto while_body_fn = [&](gtl::ArraySlice values, + auto while_body_fn = [&](absl::Span values, xla::XlaBuilder* body_builder) -> xla::StatusOr> { xla::XlaOp iteration = values[0]; diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 9493b1f109be0725f7f733b9f9da664264275a69..f2134bb4495a12b8342961d96f70e7737f816c7d 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,24 +19,24 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(gtl::ArraySlice, +typedef std::function(absl::Span, xla::XlaBuilder*)> LoopConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. typedef std::function>( - gtl::ArraySlice, xla::XlaBuilder*)> + absl::Span, xla::XlaBuilder*)> LoopBodyFunction; // Helper function for building an XLA while loop, where the values carried by @@ -50,7 +50,7 @@ typedef std::function>( xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. @@ -59,13 +59,13 @@ xla::StatusOr> XlaWhileLoop( // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. typedef std::function>( - xla::XlaOp, gtl::ArraySlice, xla::XlaBuilder*)> + xla::XlaOp, absl::Span, xla::XlaBuilder*)> ForEachIndexBodyFunction; xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2fb66913ada375d53512b9a1115326b3cc2afea4..20103ec3ae00b57723e05326dbbb1b0f6e1a671a 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -32,9 +32,25 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } -Status HostTensorsToBorrowingLiteralTuple( - tensorflow::gtl::ArraySlice host_tensors, - xla::BorrowingLiteral* literal) { +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(), + host_tensor->shape(), &xla_shape)); + return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal); +} + +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal) { + *literal = xla::MutableBorrowingLiteral( + static_cast(DMAHelper::base(host_tensor)), xla_shape); + + return Status::OK(); +} + +Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, + xla::BorrowingLiteral* literal) { std::vector buf_ptrs; buf_ptrs.reserve(host_tensors.size()); std::vector tensor_shapes(host_tensors.size()); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 0610a57029e72dff79a84742346f78a42b7f4ff1..1db7470ee2a839099454b772d4833492e033bc92 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,11 +18,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -30,12 +30,21 @@ namespace tensorflow { // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); +// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer +// owned by 'host_tensor', but is mutable via the xla::Literal methods. +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal); +// Similar as above, except the literal shape is explicitly provided and used +// instead of obtaining it from the 'host_tensor'. The provided literal shape +// 'xla_shape' must be compatible with the shape of 'host_tensor'. +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal); // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. -Status HostTensorsToBorrowingLiteralTuple( - tensorflow::gtl::ArraySlice host_tensors, - xla::BorrowingLiteral* literal); +Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, + xla::BorrowingLiteral* literal); // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index a3404c2b3df7bf25011359d1f5f5b88c29a3f83b..15f4c38da29507da9e092c1d5725b5f95a81d1b9 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -22,51 +22,61 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace { TEST(LiteralUtil, LiteralToHostTensor) { // int64 literal can only be converted to an int64 host tensor. - { - std::vector int64_values = {1, 2, 3}; - std::unique_ptr int64_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); - Tensor host_tensor; - EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", - LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) - .error_message()); - EXPECT_EQ( - "Cannot convert literal of type S64 to tensor of type qint32", - LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor) - .error_message()); - EXPECT_TRUE( - LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor) - .ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int64_values)); - } + std::vector int64_values = {1, 2, 3}; + xla::Literal int64_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int64_values)); + Tensor host_tensor; + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", + LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) + .error_message()); + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", + LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) + .error_message()); + EXPECT_TRUE( + LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int64_values)); +} + +template +using LiteralUtilTest = ::testing::Test; +using Types = + ::testing::Types, std::pair, + std::pair, std::pair, + std::pair>; + +TYPED_TEST_CASE(LiteralUtilTest, Types); + +TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) { + using int_type = typename TypeParam::first_type; + using qint_type = typename TypeParam::second_type; - { - // Repeat tests with int32. - Tensor host_tensor; - std::vector int32_values = {10, 11}; - std::unique_ptr int32_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); - EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) - .ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int32_values)); + Tensor host_tensor; + std::vector int_values = {10, 11}; + xla::Literal int_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, &host_tensor) + .ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int_values)); - EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor) - .ok()); - std::vector qint32_values = {10, 11}; - test::ExpectTensorEqual(host_tensor, - test::AsTensor(qint32_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, + &host_tensor) + .ok()); + std::vector qint_values = {10, 11}; + test::ExpectTensorEqual(host_tensor, + test::AsTensor(qint_values)); - EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64", - LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor) - .error_message()); - } + EXPECT_EQ( + error::INVALID_ARGUMENT, + LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code()); } +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index ace6fd1d8eeaf439509a7b75d8d986997c392e73..4dce0a2102cf9c782850ccc7af4f14b59bd51e53 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -11,6 +11,8 @@ cc_library( srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index a59c77f5c3a309abe8f6fbab1e48455d54e8fae5..733eeed3c661c9ed683f0fb7fd90f7f997b8dc2b 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -13,11 +13,127 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/algorithm/container.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace { + +// Helper shape function for operators that return an output with the same rank +// as their first input. +Status UnchangedRank(shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); +} + +REGISTER_OP("XlaBroadcastHelper") + .Input("lhs: T") + .Input("rhs: T") + .Input("broadcast_dims: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Output("lhs_output: T") + .Output("rhs_output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Helper operator for performing XLA-style broadcasts + +Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to +whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules +for binary operators. + +lhs: the LHS input tensor +rhs: the RHS input tensor +broadcast_dims: an XLA-style broadcast dimension specification +lhs_output: the broadcasted LHS tensor +rhs_output: the broadcasted RHS tensor +)doc"); + +REGISTER_OP("XlaConv") + .Input("lhs: T") + .Input("rhs: T") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("lhs_dilation: Tindices") + .Input("rhs_dilation: Tindices") + .Input("feature_group_count: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA ConvGeneralDilated operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution +. + +lhs: the input tensor +rhs: the kernel tensor +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +lhs_dilation: dilation to apply between input elements +rhs_dilation: dilation to apply between kernel elements +feature_group_count: number of feature groups for grouped convolution. +dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfig proto. +)doc"); + +REGISTER_OP("XlaDot") + .Input("lhs: T") + .Input("rhs: T") + .Attr("T: numbertype") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ConvGeneralDilated operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral +. + +lhs: the LHS tensor +rhs: the RHS tensor +dimension_numbers: a serialized xla::DotDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfig proto. +)doc"); + +REGISTER_OP("XlaDynamicSlice") + .Input("input: T") + .Input("start_indices: Tindices") + .Input("size_indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA DynamicSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice +. + +DynamicSlice extracts a sub-array from the input array at dynamic +start_indices. The size of the slice in each dimension is passed in +size_indices, which specify the end point of exclusive slice intervals in each +dimension -- [start, start + size). The shape of start_indices must have rank 1, +with dimension size equal to the rank of operand. + +input: A `Tensor` of type T. + +start_indices: Rank 1 tensor of N integers containing the starting indices of + the slice for each dimension. Value must be greater than or equal to zero. + +start_indices: List of N integers containing the slice size for each + dimension. Each value must be strictly greater than zero, and start + size + must be less than or equal to the size of the dimension to avoid + implementation defined behavior. +)doc"); REGISTER_OP("XlaDynamicUpdateSlice") .Input("input: T") @@ -73,6 +189,29 @@ else_branch: A function takes 'inputs' and returns a list of tensors. whose types are the same as what then_branch returns. )doc"); +REGISTER_OP("XlaPad") + .Input("input: T") + .Input("padding_value: T") + .Input("padding_low: Tindices") + .Input("padding_high: Tindices") + .Input("padding_interior: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA Pad operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#pad +. + +input: A `Tensor` of type T. +padding_value: A scalar `Tensor` of type T. +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +padding_interior: the padding to apply between each input element. +output: A `Tensor` of type T. +)doc"); + REGISTER_OP("XlaRecv") .Output("tensor: dtype") .Attr("dtype: type") @@ -98,17 +237,58 @@ tensor_name: A string key that identifies the channel. shape: The shape of the tensor. )doc"); +REGISTER_OP("XlaReduce") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("dimensions_to_reduce: list(int)") + .Attr("reducer: func") + .Output("output: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + int rank = c->Rank(c->input(0)); + std::vector dimensions_to_reduce; + TF_RETURN_IF_ERROR( + c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce)); + std::set dims_set(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + auto dim_in_range = [rank](int64 dim) { + return dim >= 0 && dim < rank; + }; + if (rank < dimensions_to_reduce.size() || + dims_set.size() != dimensions_to_reduce.size() || + !absl::c_all_of(dimensions_to_reduce, dim_in_range)) { + return errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaReduce"); + } + c->set_output( + 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size())); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); + }) + .Doc(R"doc( +Wraps the XLA Reduce operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reduce . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +reducer: a reducer function to apply +dimensions_to_reduce: dimension numbers over which to reduce +)doc"); + REGISTER_OP("XlaReduceWindow") .Input("input: T") .Input("init_value: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn(UnchangedRank) .Doc(R"doc( Wraps the XLA ReduceWindow operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . @@ -118,8 +298,35 @@ init_value: a scalar representing the initial value for the reduction computation: a reducer function to apply window_dimensions: the shape of the window window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. +padding: the padding to apply at the start and end of each input dimensions +)doc"); + +REGISTER_OP("XlaSelectAndScatter") + .Input("operand: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("source: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("select: func") + .Attr("scatter: func") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA SelectAndScatter operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter +. + +operand: the input tensor +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +source: a tensor of values to scatter +init_value: a scalar representing the initial value for the output tensor +select: a selection function to apply +scatter: a scatter function to apply )doc"); REGISTER_OP("XlaSend") @@ -179,4 +386,5 @@ body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 42b6292f79ffddd155c05758a1420a2a583eb0c6..69ca39436013ec5cf09ba502a1540d5df322e213 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -28,5 +28,6 @@ py_library( srcs = ["xla.py"], deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/xla:xla_data_proto_py", ], ) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 2fc47dffb8f5f16f24e3beb1ff75aeed3e857c58..27dd18a9bbd5aceece41aaf61eb185acb537b3b6 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -15,11 +15,12 @@ """Experimental library that exposes XLA operations directly in TensorFlow. It is sometimes useful to be able to build HLO programs directly from -TensorFlow. This file provides Tensorflow operators that map as closely as -possible to HLO operators. +TensorFlow. This file provides Tensorflow operators that mirror the semantics of +HLO operators as closely as possible. -There is no promise of backward or forward compatibility for operators defined -in this module. +Note: There is no promise of backward or forward compatibility for operators +defined in this module. This is primarily because the underlying HLO operators +do not promise backward or forward compatibility. """ from __future__ import absolute_import @@ -27,11 +28,292 @@ from __future__ import division from __future__ import print_function from tensorflow.compiler.tf2xla.ops import gen_xla_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops -# TODO(phawkins): provide wrappers for all XLA operators. +# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing +# ops include: +# infeed/outfeed (available via tf.contrib.tpu) +# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu) +# conditional +# gather/scatter +# collapse +# This file reuses builtin names (following XLA's names, so we can call things +# like xla.max), so we capture the builtin versions here. +# pylint: disable=redefined-builtin +_max = max +_min = min +_slice = slice # pylint: disable=invalid-name + +constant = constant_op.constant + +# Unary operators. + +# For most arithmetic operators there is a TensorFlow operator +# that exactly corresponds to each XLA operator. Rather than defining +# XLA-specific variants, we reuse the corresponding TensorFlow operator. +# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1 +# wrap every HLO operator, because that would allow us to be confident that the +# semantics match. + + +def _unary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def unary_op_wrapper(x, name=None): + return fn(x, name=name) + + return unary_op_wrapper + + +abs = _unary_op(math_ops.abs) +# TODO(phawkins): implement clz. +conj = _unary_op(math_ops.conj) +cos = _unary_op(math_ops.cos) +ceil = _unary_op(math_ops.ceil) +digamma = _unary_op(math_ops.digamma) +erf = _unary_op(math_ops.erf) +erfc = _unary_op(math_ops.erfc) +# TODO(phawkins): implement erfinv +exp = _unary_op(math_ops.exp) +expm1 = _unary_op(math_ops.expm1) +floor = _unary_op(math_ops.floor) +imag = _unary_op(math_ops.imag) +is_finite = _unary_op(math_ops.is_finite) +lgamma = _unary_op(math_ops.lgamma) +log = _unary_op(math_ops.log) +log1p = _unary_op(math_ops.log1p) +logical_not = _unary_op(math_ops.logical_not) +neg = _unary_op(math_ops.neg) +real = _unary_op(math_ops.real) +# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for +# numbers halfway between two integers. +round = _unary_op(math_ops.round) +sin = _unary_op(math_ops.sin) +sign = _unary_op(math_ops.sign) +tanh = _unary_op(math_ops.tanh) + +# Binary operators + +# The main difference between TensorFlow and XLA binary ops is the broadcasting +# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA +# requires an explicit specification of which dimensions to broadcast if the +# arguments have different ranks. + + +def _broadcasting_binary_op(fn): + """Wraps a binary Tensorflow operator and performs XLA-style broadcasting.""" + + def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None): + """Inner wrapper function.""" + broadcast_dims = broadcast_dims or [] + broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64) + # Rather than relying on having static shape information in the TensorFlow + # graph, we use an XlaBroadcastHelper op that can compute the correct shapes + # at JIT compilation time. + x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims) + return fn(x, y, name=name) + + return broadcasting_binary_op_wrapper + + +# Map from TF signed types to TF unsigned types. +_SIGNED_TO_UNSIGNED_TABLE = { + dtypes.int8: dtypes.uint8, + dtypes.int16: dtypes.uint16, + dtypes.int32: dtypes.uint32, + dtypes.int64: dtypes.uint64, +} + +# Map from TF unsigned types to TF signed types. +_UNSIGNED_TO_SIGNED_TABLE = { + dtypes.uint8: dtypes.int8, + dtypes.uint16: dtypes.int16, + dtypes.uint32: dtypes.int32, + dtypes.uint64: dtypes.int64, +} + + +def _shift_right_logical_helper(x, y, name=None): + """Performs an integer right logical shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + signed = dtype in _SIGNED_TO_UNSIGNED_TABLE + if signed: + unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype] + x = math_ops.cast(x, unsigned_dtype) + y = math_ops.cast(y, unsigned_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if signed: + output = math_ops.cast(output, dtype) + return output + + +def _shift_right_arithmetic_helper(x, y, name=None): + """Performs an integer right arithmetic shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE + if unsigned: + signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype] + x = math_ops.cast(x, signed_dtype) + y = math_ops.cast(y, signed_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if unsigned: + output = math_ops.cast(output, dtype) + return output + + +add = _broadcasting_binary_op(math_ops.add) +sub = _broadcasting_binary_op(math_ops.sub) +mul = _broadcasting_binary_op(math_ops.mul) +div = _broadcasting_binary_op(math_ops.div) +rem = _broadcasting_binary_op(gen_math_ops.mod) +max = _broadcasting_binary_op(math_ops.maximum) +min = _broadcasting_binary_op(math_ops.minimum) +atan2 = _broadcasting_binary_op(math_ops.atan2) +complex = _broadcasting_binary_op(math_ops.complex) +logical_and = _broadcasting_binary_op(math_ops.logical_and) +logical_or = _broadcasting_binary_op(math_ops.logical_or) +logical_xor = _broadcasting_binary_op(math_ops.logical_xor) +eq = _broadcasting_binary_op(math_ops.equal) +ne = _broadcasting_binary_op(math_ops.not_equal) +ge = _broadcasting_binary_op(math_ops.greater_equal) +gt = _broadcasting_binary_op(math_ops.greater) +le = _broadcasting_binary_op(math_ops.less_equal) +lt = _broadcasting_binary_op(math_ops.less) +pow = _broadcasting_binary_op(math_ops.pow) +shift_left = _broadcasting_binary_op(bitwise_ops.left_shift) +shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper) +shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper) + + +def _binary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def binary_op_wrapper(x, y, name=None): + return fn(x, y, name=name) + + return binary_op_wrapper + + +transpose = _binary_op(array_ops.transpose) +rev = _binary_op(array_ops.reverse) + +bitcast_convert_type = array_ops.bitcast + + +def broadcast(x, dims, name=None): + x = ops.convert_to_tensor(x) + shape = array_ops.concat( + [constant_op.constant(dims), + array_ops.shape(x)], axis=0) + return array_ops.broadcast_to(x, shape, name=name) + + +def clamp(a, x, b, name=None): + return min(max(a, x, name=name), b, name=name) + + +concatenate = array_ops.concat + + +def conv(lhs, + rhs, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + feature_group_count=1, + precision_config=None, + name=None): + """Wraps the XLA ConvGeneralDilated operator. + + ConvGeneralDilated is the most general form of XLA convolution and is + documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution + + Args: + lhs: the input tensor + rhs: the kernel tensor + window_strides: the inter-window strides + padding: the padding to apply at the start and end of each input dimensions + lhs_dilation: dilation to apply between input elements + rhs_dilation: dilation to apply between kernel elements + dimension_numbers: a `ConvolutionDimensionNumbers` proto. + feature_group_count: number of feature groups for grouped convolution. + precision_config: a `PrecisionConfigProto` proto. + name: an optional name for the operator + + Returns: + A tensor representing the output of the convolution. + """ + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_conv( + lhs, + rhs, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +convert_element_type = math_ops.cast + + +def dot(lhs, rhs, name=None): + return math_ops.tensordot(lhs, rhs, axes=1, name=name) + + +def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_dot( + lhs, + rhs, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice +# TODO(phawkins): generalize tf.pad to support interior padding, and then remove +# the XLA-specific pad operator. +pad = gen_xla_ops.xla_pad + + +def random_normal(mu, sigma, dims, name=None): + mu = ops.convert_to_tensor(mu) + return random_ops.random_normal( + dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name) + + +def random_uniform(minval, maxval, dims, name=None): + minval = ops.convert_to_tensor(minval) + return random_ops.random_uniform( + dims, minval, maxval, dtype=minval.dtype, name=name) + + +recv = gen_xla_ops.xla_recv +reduce = gen_xla_ops.xla_reduce + def reduce_window(operand, init, @@ -61,22 +343,38 @@ def reduce_window(operand, """ window_strides = window_strides or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) - padding_low = [x for (x, _) in padding] - padding_high = [y for (_, y) in padding] return gen_xla_ops.xla_reduce_window( - operand, - init, - reducer, - window_dimensions, - window_strides, - padding_low, - padding_high, + input=operand, + init_value=init, + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + computation=reducer, name=name) -recv = gen_xla_ops.xla_recv +def reshape(x, new_sizes, dimensions=None, name=None): + if dimensions is not None: + x = array_ops.transpose(x, dimensions) + x = array_ops.reshape(x, new_sizes, name=name) + return x + + +def select(condition, x, y, name=None): + return array_ops.where(condition, x, y, name) + + +select_and_scatter = gen_xla_ops.xla_select_and_scatter send = gen_xla_ops.xla_send -sort = gen_xla_ops.xla_sort +def slice(x, start_dims, limit_dims, strides): + spec = [ + _slice(start, limit, stride) + for (start, limit, stride) in zip(start_dims, limit_dims, strides) + ] + return x[tuple(spec)] + + +sort = gen_xla_ops.xla_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc new file mode 100644 index 0000000000000000000000000000000000000000..20f2ce2919701731ef6e90d368b67545af95e8f9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "absl/algorithm/container.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { +/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString( + XlaResourceOpKind op_kind) { + switch (op_kind) { + case XlaResourceOpKind::kRead: + return "Read"; + case XlaResourceOpKind::kWrite: + return "Write"; + case XlaResourceOpKind::kReadWrite: + return "Modify"; + } +} + +static gtl::FlatMap* +CreateResourceOpInfoMap() { + auto* result = new gtl::FlatMap; + + auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) { + auto insert_result = + result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); + CHECK(insert_result.second); + }; + + auto kRead = XlaResourceOpKind::kRead; + auto kWrite = XlaResourceOpKind::kWrite; + auto kReadWrite = XlaResourceOpKind::kReadWrite; + + auto kVariable = XlaResourceKind::kVariable; + auto kStack = XlaResourceKind::kStack; + auto kTensorArray = XlaResourceKind::kTensorArray; + + // clang-format off + add("AssignAddVariableOp" , kReadWrite, kVariable); + add("AssignSubVariableOp" , kReadWrite, kVariable); + add("AssignVariableOp" , kWrite, kVariable); + add("ReadVariableOp" , kRead, kVariable); + add("ResourceApplyAdaMax" , kReadWrite, kVariable); + add("ResourceApplyAdadelta" , kReadWrite, kVariable); + add("ResourceApplyAdagrad" , kReadWrite, kVariable); + add("ResourceApplyAdagradDA" , kReadWrite, kVariable); + add("ResourceApplyAdam" , kReadWrite, kVariable); + add("ResourceApplyAddSign" , kReadWrite, kVariable); + add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable); + add("ResourceApplyFtrl" , kReadWrite, kVariable); + add("ResourceApplyFtrlV2" , kReadWrite, kVariable); + add("ResourceApplyGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyPowerSign" , kReadWrite, kVariable); + add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); + add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyRMSProp" , kReadWrite, kVariable); + add("ResourceGather" , kRead, kVariable); + add("ResourceScatterAdd" , kReadWrite, kVariable); + add("ResourceScatterDiv" , kReadWrite, kVariable); + add("ResourceScatterMax" , kReadWrite, kVariable); + add("ResourceScatterMin" , kReadWrite, kVariable); + add("ResourceScatterMul" , kReadWrite, kVariable); + add("ResourceScatterNdAdd" , kReadWrite, kVariable); + add("ResourceScatterNdUpdate" , kReadWrite, kVariable); + add("ResourceScatterSub" , kReadWrite, kVariable); + add("ResourceScatterUpdate" , kReadWrite, kVariable); + add("ResourceStridedSliceAssign" , kReadWrite, kVariable); + add("VarIsInitializedOp" , kRead, kVariable); + add("VariableShape" , kRead, kVariable); + + add("StackV2" , kWrite, kStack); + add("StackCloseV2" , kRead, kStack); + add("StackPopV2" , kReadWrite, kStack); + add("StackPushV2" , kReadWrite, kStack); + + add("TensorArrayV3" , kWrite, kTensorArray); + add("TensorArrayConcatV3" , kRead, kTensorArray); + add("TensorArrayGatherV3" , kRead, kTensorArray); + add("TensorArrayScatterV3" , kWrite, kTensorArray); + add("TensorArrayGradV3" , kRead, kTensorArray); + add("TensorArrayCloseV3" , kRead, kTensorArray); + add("TensorArrayReadV3" , kRead, kTensorArray); + add("TensorArraySizeV3" , kRead, kTensorArray); + add("TensorArraySplitV3" , kWrite, kTensorArray); + add("TensorArrayWriteV3" , kWrite, kTensorArray); + // clang-format on + + return result; +} + +static const gtl::FlatMap& +GetStaticResourceOpInfoMap() { + static gtl::FlatMap* op_info_map = + CreateResourceOpInfoMap(); + return *op_info_map; +} + +const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { + const gtl::FlatMap& op_infos = + GetStaticResourceOpInfoMap(); + auto it = op_infos.find(op); + return it == op_infos.end() ? nullptr : &it->second; +} + +namespace resource_op_table_internal { +std::vector GetKnownResourceOps() { + std::vector result; + for (const auto& p : GetStaticResourceOpInfoMap()) { + result.push_back(p.first); + } + absl::c_sort(result); + return result; +} +} // namespace resource_op_table_internal +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h new file mode 100644 index 0000000000000000000000000000000000000000..61c7a56ff0d4adb75e93ced3155b37102763c652 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/logging.h" + +// Exposes information about the resource operations supported by tf2xla in a +// structured form. + +namespace tensorflow { +enum class XlaResourceOpKind { + kRead, // Only reads from resources. + kWrite, // Only writes to resources. + kReadWrite // Reads from and writes to resources. +}; + +enum class XlaResourceKind { + kVariable, // Operates on resource variables. + kStack, // Operates on stacks. + kTensorArray // Operates on tensor arrays. +}; + +class XlaResourceOpInfo { + public: + explicit XlaResourceOpInfo(XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) + : op_kind_(op_kind), resource_kind_(resource_kind) {} + + XlaResourceOpKind kind() const { return op_kind_; } + XlaResourceKind resource_kind() const { return resource_kind_; } + + static absl::string_view XlaResourceOpKindToString(XlaResourceOpKind op_kind); + + private: + XlaResourceOpKind op_kind_; + XlaResourceKind resource_kind_; +}; + +// Returns a XlaResourceOpInfo describing `op` if it is a resource operation +// supported by tf2xla, otherwise returns null (i.e. if this returns null then +// `op` is either not a resource operation or is unsupported by XLA). +const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op); + +namespace resource_op_table_internal { +// NB! Implementation detail exposed for unit testing, do not use. +// +// Returns the set of resource operations known by this module. +std::vector GetKnownResourceOps(); +} // namespace resource_op_table_internal + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a85ef040a7b65c2f6e405c3444eaef3019137b4b --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { + return arg_def.type() == DT_RESOURCE; +} + +bool HasResourceInputOrOutput(const OpDef& op_def) { + return absl::c_any_of(op_def.input_arg(), IsResourceArgDef) || + absl::c_any_of(op_def.output_arg(), IsResourceArgDef); +} + +TEST(ResourceOperationTableTest, HaveAllResourceOps) { + gtl::FlatMap known_resource_ops; + for (absl::string_view known_resource_op : + resource_op_table_internal::GetKnownResourceOps()) { + ASSERT_TRUE( + known_resource_ops.insert({string(known_resource_op), false}).second); + } + + std::vector xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); + for (const string& xla_op_name : xla_op_names) { + const OpDef* op_def; + TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def)); + if (HasResourceInputOrOutput(*op_def)) { + EXPECT_EQ(known_resource_ops.count(xla_op_name), 1) + << "Unknown resource op " << xla_op_name; + known_resource_ops[xla_op_name] = true; + } + } + + std::vector unnecessary_resource_ops; + for (const auto& pair : known_resource_ops) { + if (!pair.second) { + unnecessary_resource_ops.push_back(pair.first); + } + } + + EXPECT_TRUE(unnecessary_resource_ops.empty()) + << "Stale resource ops:\n" + << absl::StrJoin(unnecessary_resource_ops, "\n"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 9d1992205b02665b99b1bd15b7b65a1fb8c35a51..b589512dcdfa32050281120aba6a5ae89a980c2f 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -41,6 +41,14 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, // Convert a TensorShape into the equivalent XLA Shape proto. Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape) { + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); + *shape = TensorShapeToXLAShape(type, tensor_shape); + return Status::OK(); +} + +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape) { int rank = tensor_shape.dims(); std::vector dimensions(rank); std::vector layout(rank); @@ -50,11 +58,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); - xla::PrimitiveType type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); - - *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); - return Status::OK(); + return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 58240b9c965a194b9380ac7cd477ce7344e5ebe3..f7e34a5b40c2f9244c029ed325a76322b8cf54dd 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -35,6 +35,11 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape); +// Converts a TensorShape into the equivalent XLA Shape proto, taking an +// xla::PrimitiveType to specify the element type. This never fails. +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 5759c72af301785f3ca1110b58eeb2fe7dead713..8aae498be1042b5a55e849a03d438cd54dafca83 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -14,10 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "absl/strings/match.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -27,10 +26,10 @@ const char kShardingAttribute[] = "_XlaSharding"; } // namespace namespace { -xla::StatusOr> -GetShardingFromNodeDef(const NodeDef& node_def) { +xla::StatusOr> GetShardingFromNodeDef( + const NodeDef& node_def) { if (!HasNodeAttr(node_def, kShardingAttribute)) { - return tensorflow::gtl::optional(); + return absl::optional(); } string value; xla::OpSharding sharding; @@ -40,7 +39,7 @@ GetShardingFromNodeDef(const NodeDef& node_def) { "Experimental _XlaSharding attribute was not a valid encoded " "xla::OpSharding proto."); } - return tensorflow::gtl::optional(sharding); + return absl::optional(sharding); } Status CoreOutOfRangeError(int core, int num_cores_per_replica) { @@ -50,12 +49,11 @@ Status CoreOutOfRangeError(int core, int num_cores_per_replica) { } } // namespace -xla::StatusOr> -ParseShardingFromDevice( +xla::StatusOr> ParseShardingFromDevice( const string& device_name, int num_cores_per_replica, - tensorflow::gtl::optional explicit_sharding) { + absl::optional explicit_sharding) { if (device_name.empty()) { - return tensorflow::gtl::optional(); + return absl::optional(); } DeviceNameUtils::ParsedName parsed_device; if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) { @@ -66,34 +64,34 @@ ParseShardingFromDevice( if (explicit_sharding.has_value()) { return explicit_sharding; } else if (!parsed_device.has_type || !parsed_device.has_id || - !str_util::StrContains(parsed_device.type, - kDeviceSuffixReplicatedCore)) { - return tensorflow::gtl::optional(); + !absl::StrContains(parsed_device.type, + kDeviceSuffixReplicatedCore)) { + return absl::optional(); } else { const int core = parsed_device.id; if (core < 0 || core >= num_cores_per_replica) { return CoreOutOfRangeError(core, num_cores_per_replica); } - return tensorflow::gtl::optional( + return absl::optional( xla::sharding_builder::AssignDevice(core)); } } -xla::StatusOr> -ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica) { +xla::StatusOr> ParseShardingFromDevice( + const NodeDef& node_def, int num_cores_per_replica) { const string& device_name = node_def.device(); - TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + TF_ASSIGN_OR_RETURN(absl::optional sharding, GetShardingFromNodeDef(node_def)); return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); } -xla::StatusOr> -ParseShardingFromDevice(const Node& node, int num_cores_per_replica) { +xla::StatusOr> ParseShardingFromDevice( + const Node& node, int num_cores_per_replica) { string device_name = node.assigned_device_name(); if (device_name.empty()) { device_name = node.requested_device(); } - TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + TF_ASSIGN_OR_RETURN(absl::optional sharding, GetShardingFromNodeDef(node.def())); return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); } diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index b1c817bdcc211648b16e395313ca171d1acb9ea9..ab67d4f154282e3fc37b68339045deb5da91b9db 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ -#define TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ #include @@ -33,19 +33,18 @@ namespace tensorflow { // - explicit_sharding if explicit_sharding.has_value() // - a non-value if there is no assigned core or // - a sharding set as per xla::sharding_builder::AssignDevice. -xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica, - tensorflow::gtl::optional - explicit_sharding = tensorflow::gtl::nullopt); +xla::StatusOr> ParseShardingFromDevice( + const string& device_name, int num_cores_per_replica, + absl::optional explicit_sharding = absl::nullopt); -xla::StatusOr> -ParseShardingFromDevice(const Node& node, int num_cores_per_replica); +xla::StatusOr> ParseShardingFromDevice( + const Node& node, int num_cores_per_replica); -xla::StatusOr> -ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica); +xla::StatusOr> ParseShardingFromDevice( + const NodeDef& node_def, int num_cores_per_replica); void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ +#endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc index bff5978237a827cb9650541f2cf6984d9e846796..dcb7e212b74d2e261de7e125bb66b3ec78e0cfe9 100644 --- a/tensorflow/compiler/tf2xla/sharding_util_test.cc +++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc @@ -23,7 +23,7 @@ TEST(CoreUtilTest, ParseShardingFromDevice) { Graph graph(OpRegistry::Global()); auto core_from_sharding = - [](tensorflow::gtl::optional sharding) -> int64 { + [](absl::optional sharding) -> int64 { if (sharding.has_value() && sharding.value().type() == xla::OpSharding::Type::OpSharding_Type_MAXIMAL) { diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..6cd7b24592f30d7202b985f3dfd082ea2d85e344 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/side_effect_util.h" + +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; + +const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; + +std::set CalculateTokenInputsForOutputToken(const Graph& g) { + std::set results; + Node* first_side_effecting_node_on_path = nullptr; + ReverseDFS(g, + [&](Node* n) { + std::vector token_input_nodes; + if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, + &token_input_nodes) + .ok() || + token_input_nodes.empty()) { + return; + } + + if (first_side_effecting_node_on_path != nullptr) { + return; + } + + first_side_effecting_node_on_path = n; + results.insert(n->name()); + }, + [&](Node* n) { + if (first_side_effecting_node_on_path == n) { + first_side_effecting_node_on_path = nullptr; + } + }, + NodeComparatorName()); + return results; +} + +bool HasSideEffectingNodes(const Graph& g) { + for (Node* n : g.nodes()) { + std::vector token_input_nodes; + if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes) + .ok() && + !token_input_nodes.empty()) { + return true; + } + } + return false; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h new file mode 100644 index 0000000000000000000000000000000000000000..ad07624729f0b0d2443b2fc43d32dfa3377ce115 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Side-effecting nodes will have this attribute set. Its value is the list of +// node names which this node has side-effect dependencies on. +// +// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute, +// because they always have side-effect. +// If and While nodes may or may not have this attribute, depending on whether +// their bodies have side-effecting nodes. +extern const char kXlaTokenInputNodesAttrName[]; + +// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a +// node has side-effect dependency on current graph's token input. +extern const char kXlaTokenArgNodeName[]; + +// Calculates side-effect dependencies for the graph's token output. +// Returns a set of node names representing these dependencies. +std::set CalculateTokenInputsForOutputToken(const Graph& g); + +// Returns whether a graph contains side-effecting nodes. +bool HasSideEffectingNodes(const Graph& g); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc deleted file mode 100644 index 2b0834fe7b6c4d2199267dbe0ec1f7c2785aa9c7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/str_util.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/str_util.h" - -#include -#include -#include - -namespace tensorflow { -namespace str_util { - -static void ReplaceAll(string* text, StringPiece from, StringPiece to) { - size_t pos = 0; - while ((pos = text->find(from.data(), pos, from.size())) != string::npos) { - text->replace(pos, from.size(), to.data(), to.size()); - pos += to.size(); - if (from.empty()) { - pos++; // Match at the beginning of the text and after every byte - } - } -} - -void ReplaceAllPairs(string* text, - const std::vector>& replace) { - for (const std::pair& from_to : replace) { - ReplaceAll(text, from_to.first, from_to.second); - } -} - -} // namespace str_util -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h deleted file mode 100644 index 51f25009d7003db0d72296619a469ecbbbb1808d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/str_util.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// String utilities that are esoteric enough that they don't belong in -// third_party/tensorflow/core/lib/strings/str_util.h, but are still generally -// useful under xla. - -#ifndef TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ -#define TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ - -#include -#include -#include - -#include "tensorflow/core/lib/core/stringpiece.h" - -namespace tensorflow { -namespace str_util { - -// Replace all non-overlapping occurrences of the given (from,to) pairs in-place -// in text. If from is empty, it matches at the beginning of the text and after -// every byte. Each (from,to) replacement pair is processed in the order it is -// given. -void ReplaceAllPairs(string* text, - const std::vector>& replace); - -} // namespace str_util -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc deleted file mode 100644 index 8817f6902a8e58e796ca5240a9a24d7506d38793..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/str_util_test.cc +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/str_util.h" - -#include -#include -#include - -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace str_util { - -class ReplaceAllPairsTest : public ::testing::Test { - protected: - void ExpectReplaceAllPairs( - string text, const std::vector>& replace, - StringPiece want) { - ReplaceAllPairs(&text, replace); - EXPECT_EQ(text, want); - } -}; - -TEST_F(ReplaceAllPairsTest, Simple) { - ExpectReplaceAllPairs("", {}, ""); - ExpectReplaceAllPairs("", {{"", ""}}, ""); - ExpectReplaceAllPairs("", {{"", "X"}}, "X"); - ExpectReplaceAllPairs("", {{"", "XYZ"}}, "XYZ"); - ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}}, "_X_Y_Z_"); - ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}, {"_Y_", "a"}}, "_XaZ_"); - ExpectReplaceAllPairs("banana", {}, "banana"); - ExpectReplaceAllPairs("banana", {{"", ""}}, "banana"); - ExpectReplaceAllPairs("banana", {{"", "_"}}, "_b_a_n_a_n_a_"); - ExpectReplaceAllPairs("banana", {{"", "__"}}, "__b__a__n__a__n__a__"); - ExpectReplaceAllPairs("banana", {{"a", "a"}}, "banana"); - ExpectReplaceAllPairs("banana", {{"a", ""}}, "bnn"); - ExpectReplaceAllPairs("banana", {{"a", "X"}}, "bXnXnX"); - ExpectReplaceAllPairs("banana", {{"a", "XX"}}, "bXXnXXnXX"); - ExpectReplaceAllPairs("banana", {{"a", "XX"}, {"XnX", "z"}}, "bXzzX"); - ExpectReplaceAllPairs("a{{foo}}b{{bar}}c{{foo}}", - {{"{{foo}}", "0"}, {"{{bar}}", "123456789"}}, - "a0b123456789c0"); -} - -} // namespace str_util -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 3c6c9a91b6d2fb47f6dee1c347e9b852f1eea3ec..f31bfb45a2f4db270446eb59259969dc0ab63a8e 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name, return Status::OK(); } +std::unordered_map BuildNodeIndex(const Graph& graph) { + std::unordered_map index; + for (Node* node : graph.nodes()) { + index[node->name()] = node; + } + return index; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index e6e4ae92ed23f3fca0f59b131dc73152e0947b72..350a868568531c0d073e0cf600327d1ff9d62e3a 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); +// Builds a map from node name to Node* for `graph`. +std::unordered_map BuildNodeIndex(const Graph& graph); + } // namespace tensorflow +// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for +// equality. +#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \ + do { \ + string diff; \ + EqualGraphDefOptions eq_options; \ + eq_options.ignore_internal_attrs = false; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + #endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 48568c825b7a0f13011d3d6e8e62ec5db026760f..b22d53805d83069052cc5e16020d6c540d618a82 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -22,7 +22,10 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -40,8 +43,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -75,7 +76,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, auto node_it = node_map.find(remap_it->second); if (node_it == node_map.end()) { // Strip off the aot_feed_#/ prefix. - StringPiece name(remap_it->second); + absl::string_view name(remap_it->second); const auto index = name.find('/'); if (index > 0) name.remove_prefix(index + 1); return errors::InvalidArgument( @@ -89,7 +90,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, // explicitly specify or override them. Node* arg_node = nullptr; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp) + NodeBuilder(absl::StrCat("_arg_", arg_index), kArgOp) .Attr("T", BaseType(feed_node->output_type(output_index))) .Attr("index", arg_index) .Attr(kFeedIdAttr, TensorIdToString(feed.id())) @@ -136,7 +137,7 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, // Connects fetch_node -> retval_node. Node* retval_node = nullptr; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp) + NodeBuilder(absl::StrCat("_retval_", ret_index), kRetvalOp) .Input(fetch_node, id.output_index()) .Attr("T", BaseType(fetch_node->output_type(id.output_index()))) .Attr("index", ret_index) @@ -197,8 +198,8 @@ Status RewriteAndPruneGraph( if (!missing_feeds.empty() || !missing_fetches.empty()) { return errors::Aborted( "Post graph-pruning", - ", missing feeds: ", str_util::Join(missing_feeds, ", "), - ", missing fetches: ", str_util::Join(missing_fetches, ", ")); + ", missing feeds: ", absl::StrJoin(missing_feeds, ", "), + ", missing fetches: ", absl::StrJoin(missing_fetches, ", ")); } return Status::OK(); } @@ -256,7 +257,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( - strings::StrCat("/device:", DEVICE_CPU_XLA_JIT)); + absl::StrCat("/device:", DEVICE_CPU_XLA_JIT)); } std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); @@ -340,6 +341,13 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), second_copy_def, g.get())); TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); + + // Functionalize control flow. + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def)); + // After control flow functionalization, we might have more FunctionDef's + // (then/else branch, loop body). Add them to the graph. + TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); + *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc index 7aca889a266439538c4cd1c153460e6cc871b246..567d212b5eee493d29a1817987cbd7759575386e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -54,10 +54,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) { } std::sort(types.begin(), types.end()); constraints.push_back("`" + constraint.name() + "={" + - str_util::Join(types, ",") + "}`"); + absl::StrJoin(types, ",") + "}`"); } std::cout << "`" << kdef->op() << "` | " - << str_util::Join(constraints, "
") << std::endl; + << absl::StrJoin(constraints, "
") << std::endl; } std::cout << "\nTo regenerate this table, run:\n\n```shell\n" @@ -76,7 +76,7 @@ void SupportedOpsMain(int argc, char** argv, const char* regen_run) { {"device", &device, "Name of the compilation device for which to print supported ops, " "one of: " + - str_util::Join(device_names, ",")}, + absl::StrJoin(device_names, ",")}, }; string usage = Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 56f7045a98201ed398244f9e3f5ff23788135b75..ab26d939ccba75ce58609ffd71c7ccadbe90cfa8 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) { // Set up arguments. auto x_literal = xla::LiteralUtil::CreateR0(10); auto y_literal = xla::LiteralUtil::CreateR0(32); - auto x_global_or = client->TransferToServer(*x_literal); - auto y_global_or = client->TransferToServer(*y_literal); + auto x_global_or = client->TransferToServer(x_literal); + auto y_global_or = client->TransferToServer(y_literal); TF_EXPECT_OK(x_global_or.status()); TF_EXPECT_OK(y_global_or.status()); std::unique_ptr x_global = @@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) { auto result_or = client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); - std::unique_ptr result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); + xla::Literal result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 9203e8d9e607e99ad738350a1c3f2b9e900df179..d6f42bac86f1ef359531d67b652d43d851d7ac02 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -16,23 +16,27 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include +#include #include #include +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/optional.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -74,6 +78,8 @@ Status CheckFeedFetchNameConflicts(const string& kind, } // namespace +const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; + Status ValidateConfig(const tf2xla::Config& config) { std::set names; for (const tf2xla::Feed& feed : config.feed()) { @@ -111,8 +117,8 @@ Status AddPlaceholdersForFeeds( const string name_port = TensorIdToString(feed->id()); PlaceholderInfo& info = placeholder_info[name_port]; info.feed = feed; - info.placeholder_name = strings::StrCat( - "aot_feed_", feed->id().output_index(), "/", feed->id().node_name()); + info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(), + "/", feed->id().node_name()); (*feed_remapping)[name_port] = info.placeholder_name; } @@ -232,7 +238,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Push input nodes of the currently visited node to name_queue. for (const string& in_edge : map_entry.second->input()) { auto id = ParseTensorName(in_edge); - const string node_name = std::string(id.first); + const string node_name = string(id.first); if (feed_tensors.find(std::make_pair(node_name, id.second)) == feed_tensors.end()) { name_queue.push(node_name); @@ -257,7 +263,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, } string TensorIdToString(const tf2xla::TensorId& id) { - return strings::StrCat(id.node_name(), ":", id.output_index()); + return absl::StrCat(id.node_name(), ":", id.output_index()); } Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { @@ -267,7 +273,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { if (edge->IsControlEdge()) continue; const Node* possible_match = out_edges ? edge->dst() : edge->src(); TF_ASSIGN_OR_RETURN( - tensorflow::gtl::optional sharding, + absl::optional sharding, ParseShardingFromDevice( *possible_match, /*num_cores_per_replica=*/std::numeric_limits::max())); @@ -288,7 +294,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { return Status::OK(); } -void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, +void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef) { for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) { if (constraint.name() == name) { @@ -297,4 +303,126 @@ void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, } } +namespace { +uint32 InitialRandomSeed() { + // Support plumbing the TF seed through to XLA is being worked on. + // If a user wants deterministic behavior, their best option + // is to start with a known checkpoint. This also handles issues when + // multiple random calls can be invoked in any order by TF executor. + // Another option is to use stateless random ops. They have much cleaner + // semantics. + // If a user really wants to set a deterministic seed for XLA-based + // devices, this is the place to do it. + std::random_device rd; + // Make the starting value odd. + return rd() | 1; +} +} // namespace + +uint32 GetXLARandomSeed() { + // We initialize counter with an odd number and increment it by two + // everytime. This ensures that it will never be zero, even + // after an overflow. When seeded with zero, some XLA backends + // can return all zeros instead of random numbers. + static std::atomic counter(InitialRandomSeed()); + return counter.fetch_add(2); +} + +// TODO(b/77601805): add tests for associated function related stuff. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr) { + if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) { + return true; + } + + if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + return false; + } + + for (const auto& iter : node_def.attr()) { + if (iter.second.has_func()) { + return true; + } + } + + return false; +} + +std::vector GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr) { + std::vector results; + const string& op = node.type_string(); + if (flr->GetFunctionLibraryDefinition()->Contains(op)) { + // This is a function call node. + AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); + results.emplace_back(AssociatedFunctionInfo(op, attrs)); + } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + } else { + // Collect all function attrs for the node. + for (auto& iter : node.attrs()) { + if (iter.second.has_func()) { + VLOG(2) << "Found function attr for node " << node.name() << ": " + << iter.first << " = " << iter.second.func().name(); + results.emplace_back(AssociatedFunctionInfo( + iter.second.func().name(), iter.second.func().attr(), iter.first)); + } + } + } + return results; +} + +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name) { + switch (associated_function.type()) { + case AssociatedFunctionInfo::kFunctionCallNode: { + // Change this node to call the new function. + NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + for (auto attr : node->attrs()) { + builder.Attr(attr.first, attr.second); + } + for (int i = 0; i < node->num_inputs(); i++) { + Node* input_node; + TF_RETURN_IF_ERROR(node->input_node(i, &input_node)); + builder.Input(input_node->name(), i, node->input_type(i)); + } + builder.Device(node->assigned_device_name().empty() + ? node->requested_device() + : node->assigned_device_name()); + NodeDef node_def; + TF_RETURN_IF_ERROR(builder.Finalize(&node_def)); + Status s; + Node* new_node = graph->AddNode(node_def, &s); + TF_RETURN_IF_ERROR(s); + for (auto edge : node->in_edges()) { + graph->AddEdge(edge->src(), edge->src_output(), new_node, + edge->dst_input()); + } + for (auto edge : node->out_edges()) { + graph->AddEdge(new_node, edge->src_output(), edge->dst(), + edge->dst_input()); + } + graph->RemoveNode(node); + break; + } + case AssociatedFunctionInfo::kFunctionAttr: { + // Change function attr to rewritten functions. + NameAttrList func; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), associated_function.attr_name(), &func)); + node->ClearAttr(associated_function.attr_name()); + func.set_name(rewritten_function_name); + node->AddAttr(associated_function.attr_name(), func); + break; + } + } + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 745beb39c1d917cd0d1cd219536ee26a96253ec9..6065d0bb9a3abd23b8911c5049914be8a5f23b99 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -53,9 +54,73 @@ string TensorIdToString(const tf2xla::TensorId& id); Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); // Add an allowed data type to the AttrConstraint with the given name. -void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, +void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef); +// Returns the next random seed to use for seeding xla rng. +uint32 GetXLARandomSeed(); + +// Indicates how a FunctionDef is associated with a graph node (e.g. the node is +// a function call, or the node has function attrs). +class AssociatedFunctionInfo { + public: + enum AssociatedFunctionType { + kFunctionCallNode = 0, + kFunctionAttr = 1, + }; + + // The node is a function call. + AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs) + : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {} + + // The function is an attr of the node. + AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs, + const string& attr_name) + : type_(kFunctionAttr), + func_name_(func_name), + attrs_(attrs), + attr_name_(attr_name) {} + + AssociatedFunctionType type() const { return type_; } + + const string& func_name() const { return func_name_; } + + const string& attr_name() const { return attr_name_; } + + const AttrValueMap& attrs() const { return attrs_; } + + private: + // Available for all instances. + AssociatedFunctionType type_; + string func_name_; + AttrValueMap attrs_; + + // Only available if the function is defined in an attr. + string attr_name_; +}; + +// Returns if the NodeDef has associated function. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr); + +// Gets functions associated with the node. Current cases: +// 1. For function call node, its function name; +// 2. For nodes like XlaWhile/XlaIf, all their function attributes. +std::vector GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr); + +// Changes associated functions for the node. Current cases: +// 1. For function call node, creates a new node with the new function name and +// remove the old node; +// 2. For nodes like XlaWhile/XlaIf, modify their function attributes. +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name); + +// Attribute to mark nodes to be executed on host. +extern const char kXlaOutsideCompilationAttrName[]; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index ae51446204baf14dc03fc6305641048dbf3872b0..68441b3d4790b17bd06accff3fcdc8ccee79bbb7 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -24,17 +27,14 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } @@ -153,7 +153,7 @@ static tf2xla::Config FetchesConfig(std::vector fetches) { tf2xla::Config config; for (const auto& fetch_node_name : fetches) { auto* fetch = config.add_fetch(); - fetch->set_name(strings::StrCat("fetch_", fetch_node_name)); + fetch->set_name(absl::StrCat("fetch_", fetch_node_name)); fetch->mutable_id()->set_node_name(fetch_node_name); } return config; diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index c969212a1bfaa6cab0d896ee074cfd4e2b283ae4..d00b1376620c0c9d112c7d7426758f6d3f25e86f 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -26,21 +26,26 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { *type = xla::PRED; return Status::OK(); case tensorflow::DT_INT8: + case tensorflow::DT_QINT8: *type = xla::S8; return Status::OK(); case tensorflow::DT_INT16: + case tensorflow::DT_QINT16: *type = xla::S16; return Status::OK(); case tensorflow::DT_INT32: + case tensorflow::DT_QINT32: *type = xla::S32; return Status::OK(); case tensorflow::DT_INT64: *type = xla::S64; return Status::OK(); case tensorflow::DT_UINT8: + case tensorflow::DT_QUINT8: *type = xla::U8; return Status::OK(); case tensorflow::DT_UINT16: + case tensorflow::DT_QUINT16: *type = xla::U16; return Status::OK(); case tensorflow::DT_UINT32: @@ -64,12 +69,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); - case tensorflow::DT_QUINT8: - *type = xla::U8; - return Status::OK(); - case tensorflow::DT_QINT32: - *type = xla::S32; - return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index e89f4733281194f0263ae8cc4907caa0ad781165..7f860500c75667a920505dbf498e3da4b388fb90 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -76,12 +76,11 @@ class XlaCompilationAllocator : public Allocator { XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, DeviceType type) - : LocalDevice( - options, - Device::BuildDeviceAttributes( - strings::StrCat("/device:", type.type(), ":0"), type, - Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA compilation device ", type.type()))), + : LocalDevice(options, Device::BuildDeviceAttributes( + absl::StrCat("/device:", type.type(), ":0"), + type, Bytes(256 << 20), DeviceLocality(), + absl::StrCat("device: XLA compilation device ", + type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} @@ -103,7 +102,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, auto sharding_parse_result = ParseShardingFromDevice( op_kernel->def(), std::numeric_limits::max()); OP_REQUIRES_OK(context, sharding_parse_result.status()); - tensorflow::gtl::optional op_sharding = + absl::optional op_sharding = sharding_parse_result.ValueOrDie(); // If no sharding metadata is found, XLA is free to use whatever device it diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 334459138b55a201c15cb87ad9feb6a03a13c5ab..1f0f240135dfcd0c540cc39a42514c67ce979ee0 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" -#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include @@ -22,61 +21,42 @@ namespace tensorflow { XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, AllocMode alloc_mode) - : raw_function_(static_data.raw_function), - result_index_(static_data.result_index), - args_(new void*[static_data.num_args]), - temps_(new void*[static_data.num_temps]), - arg_index_to_temp_index_(new int32[static_data.num_args]), - num_args_(static_data.num_args), - arg_names_(static_data.arg_names), - result_names_(static_data.result_names), - program_shape_(static_data.program_shape), - hlo_profile_printer_data_(static_data.hlo_profile_printer_data) { + : raw_function_(static_data.raw_function_), + result_index_(static_data.result_index_), + buffer_table_(new void*[static_data.num_buffers_]), + buffer_infos_(static_data.buffer_infos_), + arg_index_table_(static_data.arg_index_table_), + num_args_(static_data.num_args_), + arg_names_(static_data.arg_names_), + result_names_(static_data.result_names_), + program_shape_(static_data.program_shape_), + hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) { + bool allocate_entry_params = + alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS; // Allocate arg and temp buffers. - if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { - alloc_args_ = cpu_function_runtime::MallocContiguousBuffers( - static_data.arg_sizes, static_data.num_args, args_, - /*annotate_initialized=*/false); - } - alloc_temps_ = cpu_function_runtime::MallocContiguousBuffers( - static_data.temp_sizes, static_data.num_temps, temps_, + alloc_buffer_table_ = cpu_function_runtime::MallocContiguousBuffers( + static_data.buffer_infos_, static_data.num_buffers_, + /*allocate_entry_params=*/allocate_entry_params, buffer_table_, /*annotate_initialized=*/true); - - for (int i = 0; i < static_data.num_temps; i++) { - if (static_data.temp_sizes[i] < -1) { - int32 param_number = -(static_data.temp_sizes[i] + 2); - arg_index_to_temp_index_[param_number] = i; - } - } - // If Hlo profiling is enabled the generated code expects an appropriately // sized buffer to be passed in as the last argument. If Hlo profiling is // disabled the last function argument is still present in the function // signature, but it is ignored by the generated code and we pass in null for // it. if (hlo_profiling_enabled()) { - profile_counters_ = new int64[static_data.profile_counters_size](); + profile_counters_ = new int64[static_data.profile_counters_size_](); } } bool XlaCompiledCpuFunction::Run() { - // Propagate pointers to the argument buffers into the temps array. Code - // generated by XLA discovers the incoming argument pointers from the temps - // array. - for (int32 i = 0; i < num_args_; i++) { - temps_[arg_index_to_temp_index_[i]] = args_[i]; - } - raw_function_(temps_[result_index_], &run_options_, nullptr, temps_, - profile_counters_); + raw_function_(buffer_table_[result_index_], &run_options_, nullptr, + buffer_table_, profile_counters_); return true; } XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { - cpu_function_runtime::FreeContiguous(alloc_args_); - cpu_function_runtime::FreeContiguous(alloc_temps_); - delete[] args_; - delete[] temps_; - delete[] arg_index_to_temp_index_; + cpu_function_runtime::FreeContiguous(alloc_buffer_table_); + delete[] buffer_table_; delete[] profile_counters_; } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 27cfb354bf5f8ede2dcca85065411006c352a575..425e769346ffcbc548495d93cb7adc779f860110 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/platform/types.h" @@ -56,46 +57,85 @@ class XlaCompiledCpuFunction { // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for // AOT this is backed by data compiled into the object file. - struct StaticData { + // + // The contents of StaticData are XLA-internal implementation details and + // should not be relied on by clients. + // + // TODO(sanjoy): Come up with a cleaner way to express the contraint we want + // here: generated XlaCompiledCpuFunction subclasses should be able to create + // instances of StaticData but only XlaCompiledCpuFunction should be able to + // read from StaticData instances. + class StaticData { + public: + void set_raw_function(RawFunction raw_function) { + raw_function_ = raw_function; + } + void set_buffer_infos( + const cpu_function_runtime::BufferInfo* buffer_infos) { + buffer_infos_ = buffer_infos; + } + void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; } + void set_arg_index_table(const int32* arg_index_table) { + arg_index_table_ = arg_index_table; + } + void set_num_args(int64 num_args) { num_args_ = num_args; } + void set_result_index(size_t result_index) { result_index_ = result_index; } + void set_arg_names(const char** arg_names) { arg_names_ = arg_names; } + void set_result_names(const char** result_names) { + result_names_ = result_names; + } + void set_program_shape(const xla::ProgramShape* program_shape) { + program_shape_ = program_shape; + } + const xla::HloProfilePrinterData* hlo_profile_printer_data() const { + return hlo_profile_printer_data_; + } + void set_hlo_profile_printer_data( + const xla::HloProfilePrinterData* hlo_profile_printer_data) { + hlo_profile_printer_data_ = hlo_profile_printer_data; + } + void set_profile_counters_size(int64 profile_counters_size) { + profile_counters_size_ = profile_counters_size; + } + + private: // The raw function to call. - RawFunction raw_function; - - // Cardinality and size of arg buffers. - const intptr_t* arg_sizes = nullptr; - size_t num_args = 0; - - // Cardinality and size of temp buffers. - // - // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer. - // - // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The - // corresponding entry in the temp buffer array needs to be set to null. - // - // If temp_sizes[i] < -1 then the i'th temp is the entry parameter - // -(temp_sizes[i] + 2). - const intptr_t* temp_sizes = nullptr; - size_t num_temps = 0; + RawFunction raw_function_; + + // Contains information about the buffers used by the XLA computation. + const cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; + size_t num_buffers_ = 0; + + // Entry parameter i is described by + // buffer_infos[arg_index_table[i]]. + const int32* arg_index_table_ = nullptr; + + // There are num_args entry parameters. + int64 num_args_ = 0; // The 0-based index of the result tuple, in the temp buffers. - size_t result_index = 0; + size_t result_index_ = 0; // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. - const char** arg_names = nullptr; - const char** result_names = nullptr; + const char** arg_names_ = nullptr; + const char** result_names_ = nullptr; // [Optional] Arg and result shapes. - const xla::ProgramShape* program_shape = nullptr; + const xla::ProgramShape* program_shape_ = nullptr; // [Optional] Profile printer data. Null if profiling is disabled. - const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr; + const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; // [Optional] The number of profile counters expected in the profile counter // buffer by the generated code and hlo_profile_printer. 0 if profiling is // disabled. This information is already present in // hlo_profile_printer_data but xla::HloProfilePrinterData is forward // declared so we don't have access to that information here. - int64 profile_counters_size = 0; + int64 profile_counters_size_ = 0; + + // Only XlaCompiledCpuFunction is allowed to read the above fields. + friend class XlaCompiledCpuFunction; }; // AllocMode controls the buffer allocation mode. @@ -135,14 +175,25 @@ class XlaCompiledCpuFunction { // ------------------------------ // Arg methods for managing input buffers. Buffers are in row-major order. - // Returns the underlying array of argument buffers, where args()[I] is the - // buffer for the positional argument at index I. - void** args() { return args_; } - const void* const* args() const { return args_; } - // Returns the buffer for the positional argument at the given `index`. - void* arg_data(size_t index) { return args_[index]; } - const void* arg_data(size_t index) const { return args_[index]; } + void* arg_data(size_t index) { + return buffer_table_[arg_index_table_[index]]; + } + const void* arg_data(size_t index) const { + return buffer_table_[arg_index_table_[index]]; + } + + int num_args() const { return num_args_; } + + // Returns the size of entry parameter `idx`. + // + // There is a static version of this method on tfcompile generated subclasses + // of XlaCompiledCpuFunction, but try to prefer this when possible since it + // works both for XlaJitCompiledCpuFunction and AOT compiled subclasses. + int arg_size(int idx) const { + assert(idx < num_args()); + return buffer_infos_[arg_index_table_[idx]].size(); + } // Sets the buffer for the positional argument at the given `index` to `data`. // Must be called before Run to have an effect. May be called under any @@ -155,7 +206,9 @@ class XlaCompiledCpuFunction { // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. - void set_arg_data(size_t index, void* data) { args_[index] = data; } + void set_arg_data(size_t index, void* data) { + buffer_table_[arg_index_table_[index]] = data; + } // ------------------------------ // Result methods for managing output buffers. Buffers are in row-major order. @@ -165,9 +218,9 @@ class XlaCompiledCpuFunction { // Returns the underlying array of result buffers, where results()[I] is the // buffer for the positional result at index I. - void** results() { return static_cast(temps_[result_index_]); } + void** results() { return static_cast(buffer_table_[result_index_]); } const void* const* results() const { - return static_cast(temps_[result_index_]); + return static_cast(buffer_table_[result_index_]); } // Profile counters for this XLA computation. @@ -225,25 +278,28 @@ class XlaCompiledCpuFunction { const RawFunction raw_function_; const size_t result_index_; - // Arrays of argument and temp buffers; entries in args_ may be overwritten by - // the user. - void** args_ = nullptr; - void** temps_ = nullptr; + // Array containing pointers to argument and temp buffers (slots corresponding + // to constant and on-stack buffers are null). + void** const buffer_table_; + + // Describes the buffers used by the XLA computation. + const cpu_function_runtime::BufferInfo* const buffer_infos_; - // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for - // XLA generated code to be able to find it. + // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] + // for XLA generated code to be able to find it. // // For now we need to keep around the args_ array because there is code that // depends on args() returning a void**. However, in the future we may remove - // args_ in favor of using temps_ as the sole storage for the arguments. - int32* arg_index_to_temp_index_; + // args_ in favor of using buffer_table_ as the sole storage for the + // arguments. + const int32* const arg_index_table_; // The number of incoming arguments. - int32 num_args_; + const int32 num_args_; - // Backing memory for individual arg and temp buffers. - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; + // Backing memory for buffer_table_ and args_, the latter depending on + // AllocMode. + void* alloc_buffer_table_ = nullptr; // Backing memory for profiling counters. int64* profile_counters_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 226c89bcf1e66b5afb43cddb03db39b931ca55a8..d5094e8ec5ed95b8cdbad63762a7fbc718ba5f30 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,11 +18,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" @@ -148,6 +149,9 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, TF_RETURN_WITH_CONTEXT_IF_ERROR( GetFunctionBody(function, flib_runtime_, fbody), "Local lookup failed with: ", status.error_message()); + VLOG(4) << "Function " << function.name() << " in flib_runtime_"; + } else { + VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } return Status::OK(); } @@ -197,14 +201,14 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (StringPiece(n->type_string()) == "_Arg") { + if (absl::string_view(n->type_string()) == "_Arg") { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). for (Node* n : graph->nodes()) { - if (StringPiece(n->type_string()) == "_Retval") { + if (absl::string_view(n->type_string()) == "_Retval") { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } @@ -212,8 +216,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileFunction: " << dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_function_", function_id), - *graph); + absl::StrCat("xla_compile_function_", function_id), *graph); } VLOG(1) << "===================================================="; @@ -291,6 +294,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, "Invalid resource type in XLAShapeForArgument()"); } } + case XlaCompiler::Argument::kToken: { + *xla_shape = xla::ShapeUtil::MakeTokenShape(); + return Status::OK(); + } case XlaCompiler::Argument::kInvalid: return errors::Internal("Invalid argument type in XLAShapeForArgument()"); } @@ -310,7 +317,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, // unique_ptr so we can capture the cleanup status in the end. xla_context->Ref(); Status status; - auto step_container = xla::MakeUnique( + auto step_container = absl::make_unique( step_id, [&status, device](const string& name) { status = device->resource_manager()->Cleanup(name); }); @@ -318,8 +325,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, step_container->name(), XlaContext::kXlaContextResourceName, xla_context)); - GraphCompiler graph_compiler(xla_context, device, graph.get(), flib, - step_container.get()); + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); TF_RETURN_IF_ERROR(graph_compiler.Compile()); // Explicitly clean up the step container, to capture the cleanup status. step_container.reset(); @@ -327,10 +333,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, } // Builds the XLA computation. -// -// `retvals` is the list of retvals produced by _Retval operators, in index -// order. `variable_map` is a map from variable ID numbers to XlaOpContext -// variable states, generated by the symbolic evaluation. +// `args` is the list of input arguments, `retvals` is the list of retvals +// produced by _Retval operators, in index order. // If `return_updated_values_for_all_resources` is true, all resources will be // included in `resource_updates`, regardless of whether their value changed. // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. @@ -360,6 +364,9 @@ Status BuildComputation( if (retval.has_constant_value()) { output.is_constant = true; output.constant_value = retval.constant_value(); + } else if (retval.resource() != nullptr) { + output.is_constant = false; + output.input_index = retval.resource()->arg_num(); } else { output.is_constant = false; elems.push_back(retval.handle()); @@ -413,7 +420,7 @@ Status BuildComputation( // Request that the value be returned on a specific core. xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? tensorflow::gtl::optional() + builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); xla::XlaOp handle; @@ -464,8 +471,6 @@ Status XlaCompiler::BuildArguments( // XLA computation as runtime parameters. input_mapping->clear(); input_mapping->reserve(args.size()); - std::vector resources; - resources.reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. for (std::vector::size_type i = 0; i < args.size(); @@ -484,10 +489,12 @@ Status XlaCompiler::BuildArguments( /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); if (arg.initialized) { - resources.push_back(i); + input_mapping->push_back(i); } + break; - case XlaCompiler::Argument::kParameter: { + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kToken: { input_mapping->push_back(i); break; } @@ -495,14 +502,11 @@ Status XlaCompiler::BuildArguments( arg_expression.set_constant_value(arg.constant_value); break; case XlaCompiler::Argument::kInvalid: - return errors::Internal("Unreachable case in BuildArguments()"); + return errors::Internal( + "Unreachable case in BuildArguments() while filling constant args"); } } - // Append parameters containing variable values after the other runtime - // parameters. - input_mapping->insert(input_mapping->end(), resources.begin(), - resources.end()); if (input_mapping->empty()) { return Status::OK(); } @@ -522,7 +526,7 @@ Status XlaCompiler::BuildArguments( // Use the _Arg nodes in the graph to resolve core assignments. for (const Node* n : graph.nodes()) { - if (StringPiece(n->type_string()) != "_Arg") continue; + if (absl::string_view(n->type_string()) != "_Arg") continue; int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); TF_RET_CHECK(index >= 0 && index < args.size()) @@ -570,7 +574,7 @@ Status XlaCompiler::BuildArguments( for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? tensorflow::gtl::optional() + builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = xla::GetTupleElement(tuple, i); } @@ -578,10 +582,10 @@ Status XlaCompiler::BuildArguments( for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? tensorflow::gtl::optional() + builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], - strings::StrCat("arg", i)); + absl::StrCat("arg", i)); } } @@ -617,9 +621,14 @@ Status XlaCompiler::BuildArguments( arg_expression.set_handle(arg_handles[i]); } break; + case XlaCompiler::Argument::kToken: { + arg_expression.set_handle(arg_handles[i]); + break; + } case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: - return errors::Internal("Unreachable case in BuildArguments()"); + return errors::Internal( + "Unreachable case in BuildArguments() while filling handles"); } } @@ -643,7 +652,7 @@ Status XlaCompiler::CompileSingleOp( // dependency edge to the _SOURCE node. for (int64 i = 0; i < ctx->num_inputs(); ++i) { Node* node; - string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); + string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); Status status = NodeBuilder(name, "_Arg") .ControlInput(graph->source_node()) .Attr("T", ctx->input_dtype(i)) @@ -656,7 +665,7 @@ Status XlaCompiler::CompileSingleOp( // Similarly with return values, create dummy _Retval nodes fed by `node`. for (int64 i = 0; i < ctx->num_outputs(); ++i) { Node* node; - string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); + string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); Status status = NodeBuilder(name, "_Retval") .Input(main_node, i) .Attr("T", ctx->expected_output_dtype(i)) @@ -692,7 +701,7 @@ Status ValidateGraph(const Graph* graph, const DeviceType& device_type, const string& name) { auto maybe_error = [&](const Node* node, const Status& s) -> Status { if (!s.ok()) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, " on ", device_type.type_string(), ": ", node->def().op(), " (", s.error_message(), ")", FormatNodeForError(*node))); @@ -733,18 +742,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_graph_", name), *graph); + absl::StrCat("xla_compile_graph_", name), *graph, + flib_runtime_->GetFunctionLibraryDefinition()); } // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); - // Converts Tensorflow's graph control-flow constructs into functional - // control-flow that can be compiled into XLA code. - TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), - graph.get(), local_flib_def_.get())); - // Detect invalid nodes. // FunctionalizeControlFlow may remove some nodes from the graph. TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, @@ -757,23 +761,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, &options_.shape_representation_fn); core::ScopedUnref context_unref(context); + std::vector real_args(args); + int token_input_index = -1; + if (options.add_token_input_output) { + // Add extra token input. + token_input_index = real_args.size(); + + XlaCompiler::Argument token_arg; + token_arg.kind = XlaCompiler::Argument::kToken; + real_args.push_back(token_arg); + } + std::vector arg_expressions; std::vector arg_cores; - TF_RETURN_IF_ERROR( - BuildArguments(*graph, args, options.use_tuple_arg, &builder, context, - &arg_cores, &arg_expressions, &result->input_mapping, - &result->xla_input_shapes, options.is_entry_computation)); + TF_RETURN_IF_ERROR(BuildArguments( + *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + &arg_expressions, &result->input_mapping, &result->xla_input_shapes, + options.is_entry_computation)); context->set_args(std::move(arg_expressions)); + PushNodeTokenMapping(); + // Use std::set instead of std::unordered_set to ensure determinism. + std::set output_node_token_inputs; + if (token_input_index != -1) { + // Original token comes from input. + auto arg_expression = context->args()[token_input_index]; + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle())); + + // Calculate token inputs for output token. + output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph); + + // If there's no side-effecting op in the graph, use token input as token + // output. + if (output_node_token_inputs.empty()) { + output_node_token_inputs.insert(kXlaTokenArgNodeName); + } + } else if (options.is_entry_computation) { + // Original token is manually created. + if (HasSideEffectingNodes(*graph)) { + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder))); + } + } + TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, flib_runtime_, NextStepId())); + if (token_input_index != -1) { + // Add extra token output. + std::vector token_inputs; + for (const auto& node_name : output_node_token_inputs) { + auto token_or = GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + TF_RETURN_IF_ERROR( + context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs))); + } + TF_RETURN_IF_ERROR(PopNodeTokenMapping()); int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared(); result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( - args, arg_cores, context->retvals(), context->resources(), + real_args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, @@ -791,14 +843,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); - // Copy the host transfer metadata to the result. - for (const auto& send : host_compute_sends_) { - *result->host_compute_metadata.add_device_to_host() = send.second; - } - for (const auto& recv : host_compute_recvs_) { - *result->host_compute_metadata.add_host_to_device() = recv.second; - } - // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); @@ -816,10 +860,34 @@ Status XlaCompiler::GetChannelHandle(const string& key, return Status::OK(); } +Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key, + xla::ChannelHandle* channel) { + auto result = channels_.emplace(key, xla::ChannelHandle()); + if (result.second) { + TF_ASSIGN_OR_RETURN(result.first->second, + client()->CreateHostToDeviceChannelHandle()); + } + *channel = result.first->second; + VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString(); + return Status::OK(); +} + +Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, + xla::ChannelHandle* channel) { + auto result = channels_.emplace(key, xla::ChannelHandle()); + if (result.second) { + TF_ASSIGN_OR_RETURN(result.first->second, + client()->CreateDeviceToHostChannelHandle()); + } + *channel = result.first->second; + VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString(); + return Status::OK(); +} + namespace { -void SetTransfer(const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes, +void SetTransfer(const string& key, absl::Span types, + absl::Span shapes, tf2xla::HostTransferMetadata* transfer) { transfer->set_key(key); CHECK(types.size() == shapes.size()); @@ -833,8 +901,8 @@ void SetTransfer(const string& key, gtl::ArraySlice types, } // namespace Status XlaCompiler::SetDeviceToHostMetadata( - const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes) { + const string& key, absl::Span types, + absl::Span shapes) { if (host_compute_sends_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( "Duplicate calls to SetDeviceToHostMetadata with key ", key); @@ -860,8 +928,8 @@ Status XlaCompiler::GetDeviceToHostShapes( } Status XlaCompiler::SetHostToDeviceMetadata( - const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes) { + const string& key, absl::Span types, + absl::Span shapes) { if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( "Duplicate calls to SetHostToDeviceMetadata with key ", key); @@ -896,4 +964,47 @@ Status XlaCompiler::SetHostComputeControlDependency( return Status::OK(); } +void XlaCompiler::PushNodeTokenMapping() { + node_token_mapping_stack_.emplace(std::map{}); +} + +Status XlaCompiler::PopNodeTokenMapping() { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is " + "empty."); + } + node_token_mapping_stack_.pop(); + return Status::OK(); +} + +Status XlaCompiler::SetNodeToken(const string& node_name, + const xla::XlaOp& op) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling SetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto insert_result = node_token_mapping_stack_.top().insert({node_name, op}); + if (!insert_result.second) { + return errors::FailedPrecondition("Token mapping already exists for node ", + node_name); + } + return Status::OK(); +} + +xla::StatusOr XlaCompiler::GetNodeToken(const string& node_name) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling GetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto iter = node_token_mapping_stack_.top().find(node_name); + if (iter == node_token_mapping_stack_.top().end()) { + return errors::FailedPrecondition("Cannot find token mapping for node ", + node_name); + } + return iter->second; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 25332c8d8e3210a0217a1ba3f5767115fe6b1d93..2cc603a58016a509fafdf6f95423dd6c0864cce3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include + #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -106,6 +109,9 @@ class XlaCompiler { // Argument is a run-time parameter. kParameter, + + // Argument is an XLA token. + kToken, }; Kind kind = kInvalid; @@ -179,10 +185,15 @@ class XlaCompiler { // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; + + // True when we should add XLA input & output to the graph/function. + bool add_token_input_output = false; }; struct OutputDescription { // Type and shape of the output. The shape is the unflattened shape. + // When `type` is DT_RESOURCE, `shape` is the shape of the resource + // variable's value. DataType type; TensorShape shape; @@ -190,6 +201,10 @@ class XlaCompiler { // 'Tensor' is in host memory. bool is_constant = false; Tensor constant_value; + + // When this output is a resource, i.e. `type == DT_RESOURCE`, this is + // the index of the input that contains the resource. + int input_index; }; // Describes a variable write side effect of the computation. @@ -212,9 +227,9 @@ class XlaCompiler { struct CompilationResult { // Vector that maps from the parameters of the XLA computation to their - // original argument positions. To handle compile-time constant inputs and - // resources, the parameters to the XLA computation may be a subset of the - // original arguments, and are not necessarily in the same order.) + // original argument positions. To handle compile-time constant inputs, the + // parameters to the XLA computation may be a subset of the original + // arguments. The relative ordering of parameters are maintained. std::vector input_mapping; // Input shapes of the computation. If we are flattening inputs, these are @@ -332,11 +347,21 @@ class XlaCompiler { // same XlaCompiler. Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); + // Retrieves the host-to-device channel handle associated with `key`. + // Allocates a new channel handle if none exists. + Status GetHostToDeviceChannelHandle(const string& key, + xla::ChannelHandle* channel); + + // Retrieves the device-to-host channel handle associated with `key`. + // Allocates a new channel handle if none exists. + Status GetDeviceToHostChannelHandle(const string& key, + xla::ChannelHandle* channel); + // Sets the shapes and types for the device to host transfer associated with // 'key'. Status SetDeviceToHostMetadata(const string& key, - gtl::ArraySlice types, - gtl::ArraySlice shapes); + absl::Span types, + absl::Span shapes); // Gets the shapes the device to host transfer associated with 'key'. Status GetDeviceToHostShapes(const string& key, @@ -345,8 +370,8 @@ class XlaCompiler { // Sets the shapes and types for the host to device transfer associated with // 'key'. Status SetHostToDeviceMetadata(const string& key, - gtl::ArraySlice types, - gtl::ArraySlice shapes); + absl::Span types, + absl::Span shapes); // In order to avoid deadlocks from dependencies in host computations, it can // be necessary to enforce a partial order on the execution of HostCompute @@ -368,6 +393,11 @@ class XlaCompiler { xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } + void PushNodeTokenMapping(); + Status PopNodeTokenMapping(); + Status SetNodeToken(const string& node_name, const xla::XlaOp& op); + xla::StatusOr GetNodeToken(const string& node_name); + private: // Sets the function body `fbody` to the one registered as `function`. Status FindFunctionBody(const NameAttrList& function, @@ -432,6 +462,15 @@ class XlaCompiler { std::unordered_map host_compute_control_output_; + // This is used to store mapping. Side-effecting + // ops call SetNodeToken() to record its token output, so later side-effecting + // ops can use GetNodeToken() to get it and use it as token input. + // + // It's a stack because we need a mapping like this for each level of nested + // CompileGraph() call. In CompileGraph(), we will push a new mapping to the + // stack, and pop the mapping before returning. + std::stack> node_token_mapping_stack_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index be00ed8813fdf2778d6af81556001ef51538dd34..72b17d04fc42eb00781e96b412465b73fb29a5c2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -14,15 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -38,7 +42,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -205,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) { std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation of a graph where the _Retval node is not necessarily last @@ -261,23 +259,68 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { args, &result)); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); +} + +// Tests that the compiler doesn't reorder the parameters. +TEST_F(XlaCompilerTest, MixedOrderArguments) { + for (bool swap_order : {false, true}) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto var = + ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, swap_order ? 0 : 1); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, swap_order ? 1 : 0); + // Adds an identity op around the resource to make sure identity ops + // propagate resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + if (swap_order) { + // Even after swapping arguments, the compiler should maintain the new + // ordering of parameters. + std::swap(args[0], args[1]); + } + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1)); + } } TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { @@ -309,10 +352,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::move(graph), args, &result); EXPECT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "depends on a parameter")) + absl::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) + absl::StrContains(status.error_message(), "[[{{node C}} = Reshape")) << status.error_message(); } @@ -357,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE( - xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } { @@ -392,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR0(7); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR0(7); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal)); } } @@ -568,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { auto instr1 = c1.instructions(j); auto instr2 = c2.instructions(j); instr1.clear_name(); + instr1.clear_id(); + instr1.clear_operand_ids(); instr2.clear_name(); - // The names of instructions were uniquified by the XlaBuilder, the rest - // of the fields should be identical. + instr2.clear_id(); + instr2.clear_operand_ids(); + // The names of instructions were uniquified by the XlaBuilder and the + // unique ids may be different, the rest of the fields should be + // identical. string str1, str2; + LOG(INFO) << "instr1 = " << instr1.DebugString(); + LOG(INFO) << "instr2 = " << instr2.DebugString(); instr1.AppendPartialToString(&str1); instr2.AppendPartialToString(&str2); EXPECT_EQ(str1, str2); @@ -621,34 +664,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { update.tensor_array_gradients_accessed); // Tests that the generated computation works. - std::unique_ptr input_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr input_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr input = - xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); + xla::Literal input_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal input_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2}); std::unique_ptr param0_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); + client_->TransferToServer(input).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr output_read = - xla::LiteralUtil::CreateR0(42); - std::unique_ptr output_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr output_grad1 = - xla::LiteralUtil::CreateR1({0, 1}); - std::unique_ptr output_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr output_resource = xla::LiteralUtil::MakeTuple( - {output_base.get(), output_grad1.get(), output_grad2.get()}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal output_read = xla::LiteralUtil::CreateR0(42); + xla::Literal output_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal output_grad1 = xla::LiteralUtil::CreateR1({0, 1}); + xla::Literal output_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal output_resource = + xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&output_read, &output_resource}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -727,8 +762,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); } @@ -807,21 +841,44 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { ASSERT_FALSE(status.ok()); // Flib lookup failure. - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); // Local flib lookup failure. - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "Attr T is not found")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found")) << status.error_message(); } +void RunAndCheckVariablesComputation( + xla::Client* client, const XlaCompiler::CompilationResult& result) { + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); + std::unique_ptr param0_data = + client->TransferToServer(param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client->TransferToServer(param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({5, 144}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); +} + // Tests a simple graph that reads and writes a variable. TEST_F(XlaCompilerTest, Variables) { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); - auto write = ops::AssignAddVariableOp(scope, var, a); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); auto read = ops::ReadVariableOp( scope.WithControlDependencies(std::vector{write}), var, DT_INT32); @@ -844,34 +901,85 @@ TEST_F(XlaCompilerTest, Variables) { // Compiles the graph. XlaCompiler compiler(DefaultOptions()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + RunAndCheckVariablesComputation(client_, result); +} + +// Tests a simple graph that reads and writes a variable. +TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0); + auto d = ops::_Retval(scope.WithOpName("D"), var, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kVariable; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = - client_ - ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + client_->Execute(*result.computation, {param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({5, 144}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); +} + +TEST_F(XlaCompilerTest, ReturnResourceHandle) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto r = ops::_Retval(scope.WithOpName("R"), var, 0); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + RunAndCheckVariablesComputation(client_, result); } xla::StatusOr> BuildTestGraph() { @@ -937,29 +1045,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. - std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = + xla::Literal expected0 = xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { @@ -1006,29 +1112,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. - std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({4, 55, 1, -3}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({27, 67, 35, 402}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({27, 67, 35, 402}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a graph which has a function with an invalid op. @@ -1075,9 +1178,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}")) << status.error_message(); } @@ -1100,10 +1203,10 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "is not in the list of allowed values")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "is not in the list of allowed values")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}")) << status.error_message(); } @@ -1123,25 +1226,73 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result); - ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) - << status.error_message(); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result)); } +} + +class DummySideEffectingOp : public XlaOpKernel { + public: + explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken( + name(), xla::CreateToken(ctx->builder()))); + } +}; + +REGISTER_OP("DummySideEffectingOp"); + +REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp); - // Fix control edges for NoOp. +TEST_F(XlaCompilerTest, TokenInputAndOutput) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef side_effecting_op; + side_effecting_op.set_name("DummySideEffectingOp"); + side_effecting_op.set_op("DummySideEffectingOp"); + AddNodeAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}, &side_effecting_op); + Status status; + graph->AddNode(side_effecting_op, &status); + TF_ASSERT_OK(status); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get())); + + const std::vector empty_args; { + // The case for entry computation: we don't add token input/output. Instead, + // we use CreateToken HLO to create the entry token. + XlaCompiler::CompileOptions options; + options.is_entry_computation = true; + options.add_token_input_output = false; + XlaCompiler compiler(DefaultOptions()); + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); - EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result)); - EXPECT_EQ(0, result.resource_updates.size()); + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 0); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0); + } + { + // The case for non-entry computation (e.g. while loop body). We add token + // input/output. + XlaCompiler::CompileOptions options; + options.is_entry_computation = false; + options.add_token_input_output = true; + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0])); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken( + xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0))); } } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index b24e3aabbe6ba858a8bfb4dd435726984cc7b0f5..f247570d72c0287a33695de3d778cce2a2418921 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -31,8 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -107,6 +106,30 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, return Status::OK(); } +Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { + VLOG(1) << "Adding retval index " << retval_index << " with resource " + << resource->name() << ":" << resource->shape().DebugString() + << " to XLA computation"; + if (retvals_.size() <= retval_index) { + retvals_.resize(retval_index + 1); + } + XlaExpression e; + e.set_resource(resource); + retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e}; + return Status::OK(); +} + +Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) { + VLOG(1) << "Adding retval index " << retvals_.size() + << " with token to XLA computation"; + XlaExpression e; + e.set_handle(token); + // We use DT_INVALID because there is no TF DataType which corresponds to XLA + // token. XlaCompiler handles this case separately, so putting it here is OK. + retvals_.push_back(Retval{DT_INVALID, TensorShape(), e}); + return Status::OK(); +} + xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 3db37afdba71342cfb20af8841a40cb54709ca73..d7dbdc957f0e7969db5098b815381866cdc71ab6 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -86,6 +86,12 @@ class XlaContext : public ResourceBase { Status AddConstRetval(int retval_index, DataType dtype, const xla::LiteralSlice& literal); + // As for Retval, but for return values that are resource handles. + Status AddResourceRetval(int retval_index, XlaResource* resource); + + // As for Retval, but for return values that are XLA tokens. + Status AppendTokenRetval(const xla::XlaOp& token); + // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` // constructor for a description of the remaining arguments. diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index 23d04d43b358e858ad1ab2463322ce0ab93b23c2..bc44301d405102921de21da4bd9407032783838c 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -20,21 +20,6 @@ limitations under the License. namespace tensorflow { bool CpuOpFilter(KernelDef* kdef) { - // TODO(b/34339814): implement inverse erf for double types and remove this - // workaround. - if (kdef->op() == "RandomStandardNormal") { - kdef->clear_constraint(); - // Change the type constraint to permit only DTD_FLOAT. - KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); - attr_constraint->set_name("dtype"); - attr_constraint->mutable_allowed_values()->mutable_list()->add_type( - DT_FLOAT); - return true; - } - // TODO(b/26783907): The CPU backend currently does not implement sort. - if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") { - return false; - } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 8efb3d55c88757b9366bdf9622287bdd0a72e295..9a34cd8c6ae2dc6d52a3cc69168df96f5322c6da 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -119,7 +119,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, } /* static */ Status XlaHelpers::ReshapeLiteral( - const xla::Literal& input, gtl::ArraySlice dimensions, + const xla::Literal& input, absl::Span dimensions, xla::Literal* output) { if (xla::ShapeUtil::IsTuple(input.shape())) { return errors::InvalidArgument("ReshapeLiteral does not support tuples."); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index e6522157a535fc3e4ec96cb0496b6be2e525c336..39578144caaadf293d24ea91aa874e56e27ecc01 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -18,10 +18,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -50,7 +50,7 @@ class XlaHelpers { // Reshapes literal 'input' to have 'shape'. Both the original shape and // 'shape' must contain the same number of elements. static Status ReshapeLiteral(const xla::Literal& input, - gtl::ArraySlice shape, + absl::Span shape, xla::Literal* output); // Returns the argmax of `input` along `axis`. `output_type` is the type to diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 114a9241bdb00526df76478b030a9efa506dd29c..86a78ee429e8913edb4a948727fa692083c472f4 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,45 +36,6 @@ limitations under the License. namespace tensorflow { namespace { - -// Returns a vector of positional argument buffer sizes. -xla::StatusOr> ComputeArgSizes( - const xla::ProgramShape& program_shape) { - std::vector arg_sizes; - const size_t num_args = program_shape.parameters_size(); - arg_sizes.reserve(num_args); - for (int i = 0; i < num_args; ++i) { - const xla::Shape& arg_shape = program_shape.parameters(i); - constexpr size_t kPointerSize = sizeof(void*); - arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); - } - return std::move(arg_sizes); -} - -// Returns a vector of positional temporary buffer sizes. -xla::StatusOr> ComputeTempSizes( - const xla::BufferAssignment& buffer_assignment) { - const std::vector& allocations = - buffer_assignment.Allocations(); - std::vector temp_sizes; - temp_sizes.reserve(allocations.size()); - for (const xla::BufferAllocation& allocation : allocations) { - if (allocation.is_constant() || allocation.is_thread_local()) { - // Constants are lowered to globals. Thread locals are lowered to - // allocas. - temp_sizes.push_back(-1); - } else if (allocation.is_entry_computation_parameter()) { - // Entry computation parameters need some preprocessing in - // XlaCompiledCpuFunction::Run. See the comment on - // XlaCompiledCpuFunction::StaticData::temp_sizes. - temp_sizes.push_back(-allocation.parameter_number() - 2); - } else { - temp_sizes.push_back(allocation.size()); - } - } - return std::move(temp_sizes); -} - // Returns the index of the result in the temp buffers. xla::StatusOr ComputeResultIndex( const xla::BufferAssignment& buffer_assignment) { @@ -157,11 +119,11 @@ XlaJitCompiledCpuFunction::Compile( const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); - // Compute buffer sizes and the result index, needed to run the raw function. - TF_ASSIGN_OR_RETURN(std::vector arg_sizes, - ComputeArgSizes(*program_shape)); - TF_ASSIGN_OR_RETURN(std::vector temp_sizes, - ComputeTempSizes(buffer_assignment)); + // Compute buffer infos and the result index, needed to run the raw function. + std::vector buffer_infos = + xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment); + std::vector arg_index_table = + xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); TF_ASSIGN_OR_RETURN(size_t result_index, ComputeResultIndex(buffer_assignment)); @@ -169,28 +131,28 @@ XlaJitCompiledCpuFunction::Compile( new XlaJitCompiledCpuFunction); XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get(); jit->executable_ = std::move(executable); - jit->arg_sizes_ = std::move(arg_sizes); - jit->temp_sizes_ = std::move(temp_sizes); + jit->buffer_infos_ = std::move(buffer_infos); + jit->arg_index_table_ = std::move(arg_index_table); jit->program_shape_ = std::move(program_shape); - jit->static_data_.raw_function = std::move(raw_function); - jit->static_data_.arg_sizes = jit->arg_sizes_.data(); - jit->static_data_.num_args = jit->arg_sizes_.size(); - jit->static_data_.temp_sizes = jit->temp_sizes_.data(); - jit->static_data_.num_temps = jit->temp_sizes_.size(); - jit->static_data_.result_index = result_index; + jit->static_data_.set_raw_function(raw_function); + jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); + jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); + jit->static_data_.set_arg_index_table(jit->arg_index_table_.data()); + jit->static_data_.set_num_args(jit->arg_index_table_.size()); + jit->static_data_.set_result_index(result_index); // Optional metadata is collected and set below. CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, &jit->result_names_); - jit->static_data_.arg_names = jit->arg_names_.data(); - jit->static_data_.result_names = jit->result_names_.data(); - jit->static_data_.program_shape = jit->program_shape_.get(); + jit->static_data_.set_arg_names(jit->arg_names_.data()); + jit->static_data_.set_result_names(jit->result_names_.data()); + jit->static_data_.set_program_shape(jit->program_shape_.get()); if (cpu_executable->hlo_profiling_enabled()) { - jit->static_data_.hlo_profile_printer_data = - &cpu_executable->hlo_profile_printer_data(); - jit->static_data_.profile_counters_size = - cpu_executable->hlo_profile_printer_data().profile_counters_size(); + jit->static_data_.set_hlo_profile_printer_data( + &cpu_executable->hlo_profile_printer_data()); + jit->static_data_.set_profile_counters_size( + cpu_executable->hlo_profile_printer_data().profile_counters_size()); } return std::move(jit_unique_ptr); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index af307ae4eff74927242c4650d8a43710e991cc52..d3c8f22a8078d03d15447ed200c914390f40b04f 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -66,9 +66,11 @@ class XlaJitCompiledCpuFunction { // The static data is backed by the rest of the state in this class. XlaCompiledCpuFunction::StaticData static_data_; - // The backing arrays of arg and temp buffer sizes. - std::vector arg_sizes_; - std::vector temp_sizes_; + // The backing array for buffer infos. + std::vector buffer_infos_; + + // The backing array for the arg index table. + std::vector arg_index_table_; // The backing arrays of arg and result names. We hold the actual strings in // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 82028c8b9ca9f65a73f8b50edc0a47c7068aba9a..2a9eaeee146bf6d792e010df7e041f9986b2c77e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -67,7 +67,7 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) { return GetComputationFromTensor(context_->input(index)); } -const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) { +const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { return GetComputationFromTensor(GetInputTensorByName(name)); } @@ -75,7 +75,7 @@ TensorShape XlaOpKernelContext::InputShape(int index) { return context_->input(index).shape(); } -TensorShape XlaOpKernelContext::InputShape(StringPiece name) { +TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { return GetInputTensorByName(name).shape(); } @@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const { return context_->input(index).dtype(); } +DataType XlaOpKernelContext::InputType(absl::string_view name) { + return GetInputTensorByName(name).dtype(); +} + xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { xla::PrimitiveType type; Status status = DataTypeToPrimitiveType(input_type(index), &type); @@ -99,8 +103,27 @@ Status XlaOpKernelContext::ConstantInput(int index, index, context_->input(index).shape().dim_sizes(), constant_literal); } +static xla::StatusOr InputIndex(XlaOpKernelContext* context, + absl::string_view name) { + int start, stop; + TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + return start; +} + +Status XlaOpKernelContext::ConstantInput(absl::string_view name, + xla::Literal* constant_literal) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInput(index, constant_literal); +} + Status XlaOpKernelContext::ConstantInputReshaped( - int index, gtl::ArraySlice new_dims, + int index, absl::Span new_dims, xla::Literal* constant_literal) { const Tensor& tensor = context_->input(index); TensorShape new_shape(new_dims); @@ -194,16 +217,15 @@ Status XlaOpKernelContext::ConstantInputReshaped( context_->op_kernel().name(), " input ", index, ".\nError: ", constant_graph.status().error_message()); } - xla::StatusOr> computed = - compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), - &layout); + xla::StatusOr computed = compiler()->client()->ComputeConstant( + constant_graph.ValueOrDie(), &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, - "as a compile-time constant.\nError: ", + " as a compile-time constant.\nError: ", computed.status().error_message()); } - *constant_literal = std::move(*computed.ValueOrDie()); + *constant_literal = std::move(computed).ValueOrDie(); return Status::OK(); } @@ -246,6 +268,12 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { return LiteralToInt64Scalar(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name, + int64* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntScalar(index, out); +} + Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); @@ -280,6 +308,20 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name, + std::vector* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntVector(index, out); +} + +Status XlaOpKernelContext::ConstantInputReshapedToIntVector( + int index, std::vector* out) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInputReshaped( + index, {InputShape(index).num_elements()}, &literal)); + return LiteralToInt64Vector(literal, out); +} + Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, xla::Literal* out) { xla::Literal literal; @@ -305,6 +347,12 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, } } +Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name, + xla::Literal* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsInt64Literal(index, out); +} + // TODO(phawkins): validate that the dimensions form a valid shape, fail // gracefully if they do not. Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { @@ -316,7 +364,7 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } -Status XlaOpKernelContext::InputList(StringPiece name, +Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { OpInputList inputs; @@ -331,7 +379,7 @@ Status XlaOpKernelContext::InputList(StringPiece name, } Status XlaOpKernelContext::ConstantInputList( - StringPiece name, std::vector* outputs) { + absl::string_view name, std::vector* outputs) { int start, stop; TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); outputs->resize(stop - start); @@ -384,8 +432,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, value); } -Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type, - TensorShape* shape, +Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, + DataType type, TensorShape* shape, xla::XlaOp* value) { return ReadVariableInputTensor(GetInputTensorByName(name), type, context_, shape, value); @@ -519,7 +567,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, handle, builder()); } -Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type, +Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); return AssignVariableTensor(GetInputTensorByName(name), type, context_, @@ -565,7 +613,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( return XlaContext::Get(context_).GetOrCreateMul(type); } -const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) { +const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; CHECK(context_->input(name, &tensor).ok()); return *tensor; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index ac9dfe3369078df7392a4ef04679f7d7beacf8bb..a3a0d10cc06cd4afceec728b7dbe287389099b9d 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -71,6 +71,9 @@ class XlaOpKernelContext { // Returns the type of input `index`. DataType input_type(int index) const; + // Returns the type of input `name`. + DataType InputType(absl::string_view name); + // Returns the type of input `index` as an xla::PrimitiveType. If the type // is not representable as an XLA type, sets an error status and returns // xla::PRIMITIVE_TYPE_INVALID. @@ -79,15 +82,15 @@ class XlaOpKernelContext { // Returns the shape of input `index`. TensorShape InputShape(int index); - // Returns the shape of input `name`. - TensorShape InputShape(StringPiece name); + // Returns the shape of input with name `name`. + TensorShape InputShape(absl::string_view name); // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. const xla::XlaOp& Input(int index); // Returns input `name` as a XlaOp. - const xla::XlaOp& Input(StringPiece name); + const xla::XlaOp& Input(absl::string_view name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -97,7 +100,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status InputList(StringPiece name, std::vector* handles, + Status InputList(absl::string_view name, std::vector* handles, std::vector* shapes); // Helper methods for constant inputs. @@ -106,26 +109,35 @@ class XlaOpKernelContext { // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); + Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); // Evaluates input `index`, reshapes it to `new_shape` if new_shape != // InputShape(index), and stores it in `*constant_literal`. If the input // cannot be evaluated, e.g., because it depends on unbound parameters, // returns a non-Ok status. If InputShape(index).num_elements() != // new_shape.num_elements(), returns an error status. - Status ConstantInputReshaped(int index, gtl::ArraySlice new_shape, + Status ConstantInputReshaped(int index, absl::Span new_dims, xla::Literal* constant_literal); // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); + Status ConstantInputAsIntScalar(absl::string_view name, int64* out); // Converts a constant scalar float32 or float64 tensor into a float64. Status ConstantInputAsFloatScalar(int index, double* out); // Converts a constant 1D int32 or int64 tensor into a vector of int64s. Status ConstantInputAsIntVector(int index, std::vector* out); + Status ConstantInputAsIntVector(absl::string_view name, + std::vector* out); + + // Reshapes and converts a constant int32 or int64 tensor into a vector of + // int64s. + Status ConstantInputReshapedToIntVector(int index, std::vector* out); // Converts a constant int32 or int64 Tensor into an xla int64 Literal. Status ConstantInputAsInt64Literal(int index, xla::Literal* out); + Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out); // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); @@ -133,7 +145,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status ConstantInputList(StringPiece name, + Status ConstantInputList(absl::string_view name, std::vector* literals); // Outputs @@ -182,8 +194,8 @@ class XlaOpKernelContext { xla::XlaOp* value); // Reads the current value of the resouce variable referred to by input // `name`. - Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape, - xla::XlaOp* value); + Status ReadVariableInput(absl::string_view name, DataType type, + TensorShape* shape, xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the @@ -191,7 +203,8 @@ class XlaOpKernelContext { // different shape. Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); // Assigns the value `handle` to the variable referenced by input `name`. - Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle); + Status AssignVariable(absl::string_view name, DataType type, + xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -240,7 +253,7 @@ class XlaOpKernelContext { private: // Returns the tensor of input `name`. - const Tensor& GetInputTensorByName(StringPiece name); + const Tensor& GetInputTensorByName(absl::string_view name); OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 46785bc1f0a1279bfd67a55844fe238d9797382b..91d48125f1d21092db7e5f9307e44af9c16e4e2b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -90,6 +90,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible compile time constant inputs."; return false; } + if (x.is_metadata_op != y.is_metadata_op) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible values for is_metadata_op."; + return false; + } return true; } @@ -105,7 +110,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; /* static */ void XlaOpRegistry::RegisterBackend( const string& compilation_device_name, - gtl::ArraySlice supported_types, BackendOpFilter op_filter) { + absl::Span supported_types, BackendOpFilter op_filter) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto result = registry.backends_.emplace(compilation_device_name, Backend()); @@ -325,6 +330,17 @@ std::vector XlaOpRegistry::DeviceKernels( return kernels; } +/*static*/ std::vector XlaOpRegistry::GetAllRegisteredOps() { + std::vector ops; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + for (const auto& pair : registry.ops_) { + ops.push_back(pair.first); + } + std::sort(ops.begin(), ops.end()); + return ops; +} + /* static */ const std::unordered_set* XlaOpRegistry::CompileTimeConstantInputs(const string& op) { XlaOpRegistry& registry = Instance(); @@ -339,6 +355,20 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) { return &it->second.front()->compile_time_constant_inputs; } +/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end() || it->second.empty()) { + return false; + } + + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // is_metadata_op, so only the first match is returned. + return it->second.front()->is_metadata_op; +} + std::vector XlaOpRegistry::BackendNames() { std::vector names; XlaOpRegistry& registry = Instance(); @@ -360,28 +390,30 @@ XlaOpRegistry& XlaOpRegistry::Instance() { return *r; } -XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { +XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) { registration_.reset(new XlaOpRegistry::OpRegistration); - registration_->name = std::string(name); + registration_->name = string(name); } -XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { +XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name( + absl::string_view name) { XlaOpRegistrationBuilder registration(name); return registration; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( - gtl::ArraySlice devices) { + absl::Span devices) { registration_->has_device_whitelist = true; - for (StringPiece device : devices) { - registration_->device_whitelist.insert(std::string(device)); + for (absl::string_view device : devices) { + registration_->device_whitelist.emplace(device); } return *this; } -XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( + absl::string_view device) { registration_->has_device_whitelist = true; - registration_->device_whitelist.insert(std::string(device)); + registration_->device_whitelist.emplace(device); return *this; } @@ -396,17 +428,17 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, DataType allowed) { + absl::string_view attr_name, DataType allowed) { std::set& types = - registration_->type_constraints[std::string(attr_name)]; + registration_->type_constraints[string(attr_name)]; types.insert(allowed); return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, gtl::ArraySlice allowed) { + absl::string_view attr_name, absl::Span allowed) { std::set& types = - registration_->type_constraints[std::string(attr_name)]; + registration_->type_constraints[string(attr_name)]; for (DataType t : allowed) { types.insert(t); } @@ -414,8 +446,13 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( - StringPiece input_name) { - registration_->compile_time_constant_inputs.insert(std::string(input_name)); + absl::string_view input_name) { + registration_->compile_time_constant_inputs.emplace(input_name); + return *this; +} + +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() { + registration_->is_metadata_op = true; return *this; } @@ -441,10 +478,10 @@ XlaOpRegistrar::XlaOpRegistrar( } XlaBackendRegistrar::XlaBackendRegistrar( - StringPiece name, gtl::ArraySlice types, + absl::string_view name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); - registry.RegisterBackend(std::string(name), types, op_filter); + registry.RegisterBackend(string(name), types, op_filter); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index fc14834ca6441ea785eacc57e1f502086f36657e..4b2c2bacd647b3e6fe500a942b116772550195ce 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -47,17 +47,18 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; -constexpr std::array kNumericTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BFLOAT16}}; +constexpr std::array kNumericTypes = { + {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kCpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; -constexpr std::array kGpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, + DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. @@ -94,7 +95,7 @@ class XlaOpRegistry { // the device; it may optionally modify the KernelDef. typedef bool (*BackendOpFilter)(KernelDef* kdef); static void RegisterBackend(const string& compilation_device_name, - gtl::ArraySlice supported_types, + absl::Span supported_types, BackendOpFilter op_filter); // Returns the names of the registered backends. @@ -128,11 +129,18 @@ class XlaOpRegistry { const string& compilation_device_name, bool include_compilation_only_kernels); + // Returns all operations for which there are XLA kernels on any device. + static std::vector GetAllRegisteredOps(); + // Returns the set of compile-time constant inputs to 'op'. Returns nullptr // if the op is not registered. static const std::unordered_set* CompileTimeConstantInputs( const string& op); + // Returns true if `op` is a "metadata" op, one that only looks at the shapes + // of its operands and not their values. + static bool IsMetadataOp(const string& op); + private: friend class XlaBackendRegistrar; friend class XlaOpRegistrar; @@ -189,6 +197,10 @@ class XlaOpRegistry { // Names of arguments that must be compile-time constants. std::unordered_set compile_time_constant_inputs; + // True if this is a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + bool is_metadata_op = false; + // Factory used to build OpKernels that perform symbolic execution. Factory factory; }; @@ -229,19 +241,19 @@ class XlaOpRegistry { class XlaOpRegistrationBuilder { public: // Starts an operator registration chain. - static XlaOpRegistrationBuilder Name(StringPiece name); + static XlaOpRegistrationBuilder Name(absl::string_view name); // Specifies a whitelist of devices on which the operator may run. - XlaOpRegistrationBuilder& Device(StringPiece devices); - XlaOpRegistrationBuilder& Device(gtl::ArraySlice devices); + XlaOpRegistrationBuilder& Device(absl::string_view devices); + XlaOpRegistrationBuilder& Device(absl::Span devices); // Specifies a type constraint for a type variable attribute. Each constraint // specifies the set of types that the type variable may assume. - XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, DataType allowed); - XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, - gtl::ArraySlice allowed); + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, + absl::Span allowed); // Specifies that a dummy copy of this operator should not be registered on // XLA_* devices, but may be used during compilation. @@ -251,13 +263,17 @@ class XlaOpRegistrationBuilder { XlaOpRegistrationBuilder& AllowResourceTypes(); // Mark 'input_name' as an argument whose value must be known at compile-time. - XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name); + XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name); + + // Mark this op as a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + XlaOpRegistrationBuilder& IsMetadataOp(); std::unique_ptr Build( XlaOpRegistry::Factory factory); private: - XlaOpRegistrationBuilder(StringPiece name); + XlaOpRegistrationBuilder(absl::string_view name); std::unique_ptr registration_; }; @@ -285,7 +301,7 @@ class XlaOpRegistrar { class XlaBackendRegistrar { public: - XlaBackendRegistrar(StringPiece name, gtl::ArraySlice types, + XlaBackendRegistrar(absl::string_view name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter = nullptr); }; diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 7928fa034725206a752cbfe086d01f15cd235df9..56c2e01055665954b99ea635e56666fbd8b96026 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -43,7 +43,7 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, for (const string& gradient : tensor_array_gradients) { tensor_array_gradients_[gradient].reset(new XlaResource( /*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_, + /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{})); } } @@ -135,7 +135,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, gradient_value, tensor_array_size_, /*tensor_array_gradients=*/{})); } diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index fdf13bb18c2567d2994612d15119ae87cbfa9137..cc7390c6e60375b4c31c38f9f7dee25730f8f51e 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -113,6 +113,7 @@ cc_library( ":statusor", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -161,7 +162,6 @@ cc_library( "iterator_util.h", "map_util.h", "overflow_util.h", - "ptr_util.h", "util.h", ], visibility = ["//visibility:public"], @@ -172,7 +172,11 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:ptr_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -189,6 +193,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", ], ) @@ -210,6 +215,7 @@ tf_cc_test( ":test", ":util", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -236,10 +242,14 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -256,6 +266,7 @@ tf_cc_test( ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -297,6 +308,10 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -315,6 +330,8 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -335,6 +352,9 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -353,6 +373,8 @@ cc_library( ":literal_util", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -364,6 +386,8 @@ cc_library( deps = [ ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -373,8 +397,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -385,6 +409,8 @@ cc_library( ":status", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -405,8 +431,9 @@ cc_library( deps = [ ":array", ":types", - ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -451,6 +478,8 @@ cc_library( ":array2d", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -462,6 +491,7 @@ tf_cc_test( ":test", "//tensorflow/core:lib", "//tensorflow/core:test_main", + "@com_google_absl//absl/types:span", ], ) @@ -489,6 +519,8 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/base", + "@com_google_absl//absl/memory", ], ) @@ -503,6 +535,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -521,6 +554,8 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -551,6 +586,8 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -576,10 +613,12 @@ cc_library( deps = [ ":shape_util", ":status_macros", - ":util", ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -593,6 +632,7 @@ tf_cc_test( ":xla_data_proto", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -619,6 +659,8 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -642,6 +684,8 @@ cc_library( "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -660,6 +704,7 @@ tf_cc_test( "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -671,7 +716,8 @@ cc_library( ":array2d", ":shape_util", ":xla_data_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index ea75ad32d5df7bbadd37e89de6144b264ab6d5d1..58cc1575858201b4508d7340cb47e59c4f4c5783 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -27,12 +27,12 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -97,12 +97,11 @@ class Array { using value_type = T; // Creates a new array with the specified dimensions. - explicit Array(tensorflow::gtl::ArraySlice sizes) - : Array(sizes, T()) {} + explicit Array(absl::Span sizes) : Array(sizes, T()) {} // Creates a new array with the specified dimensions and specified value for // every cell. - Array(tensorflow::gtl::ArraySlice sizes, T value) + Array(absl::Span sizes, T value) : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) { Fill(value); } @@ -301,7 +300,7 @@ class Array { // Invokes a callback with the (indices, value_ptr) for each cell in the // array. - void Each(std::function, T*)> f) { + void Each(std::function, T*)> f) { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { f(index, &values_[i]); @@ -309,8 +308,7 @@ class Array { } // Invokes a callback with the (indices, value) for each cell in the array. - void Each( - std::function, T)> f) const { + void Each(std::function, T)> f) const { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { f(index, values_[i]); @@ -320,8 +318,7 @@ class Array { // Invokes a callback with the (indices, value_ptr) for each cell in the // array. If a callback returns a non-OK status, returns that else returns // Status::OK(). - Status EachStatus( - std::function, T*)> f) { + Status EachStatus(std::function, T*)> f) { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { Status s = f(index, &values_[i]); @@ -335,8 +332,7 @@ class Array { // Invokes a callback with the (indices, value) for each cell in the array. // If a callback returns a non-OK status, returns that else returns // Status::OK(). - Status EachStatus( - std::function, T)> f) const { + Status EachStatus(std::function, T)> f) const { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { Status s = f(index, values_[i]); @@ -377,13 +373,13 @@ class Array { // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. - const T& operator()(tensorflow::gtl::ArraySlice indexes) const { + const T& operator()(absl::Span indexes) const { return values_[calculate_index(indexes)]; } // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. - T& operator()(tensorflow::gtl::ArraySlice indexes) { + T& operator()(absl::Span indexes) { return values_[calculate_index(indexes)]; } @@ -409,7 +405,7 @@ class Array { // Returns the total number of elements in the array. int64 num_elements() const { - return std::accumulate(sizes_.begin(), sizes_.end(), 1, + return std::accumulate(sizes_.begin(), sizes_.end(), 1LL, std::multiplies()); } @@ -438,8 +434,8 @@ class Array { bool operator!=(const Array& other) const { return !(*this == other); } // Performs the equivalent of a slice operation on this array. - Array Slice(tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits) const { + Array Slice(absl::Span starts, + absl::Span limits) const { CHECK_EQ(starts.size(), num_dimensions()); CHECK_EQ(limits.size(), num_dimensions()); @@ -464,7 +460,7 @@ class Array { // Performs the equivalent of a DynamicUpdateSlice in-place on this array. void UpdateSlice(const Array& from, - tensorflow::gtl::ArraySlice start_indices) { + absl::Span start_indices) { CHECK_EQ(from.num_dimensions(), num_dimensions()); std::vector limit_indices; std::transform(start_indices.begin(), start_indices.end(), @@ -484,7 +480,7 @@ class Array { // Performs an in-place reshape, modifying the dimensions but not the // underlying data. - void Reshape(tensorflow::gtl::ArraySlice new_dimensions) { + void Reshape(absl::Span new_dimensions) { int64 old_num_elements = num_elements(); sizes_ = std::vector(new_dimensions.begin(), new_dimensions.end()); CHECK_EQ(num_elements(), old_num_elements); @@ -507,9 +503,7 @@ class Array { } } - pieces.push_back( - tensorflow::strings::AlphaNum(values_[calculate_index(index)]) - .data()); + pieces.push_back(absl::StrCat(values_[calculate_index(index)])); // Emit comma if it isn't the last element if (index.back() != sizes_.back() - 1) { @@ -527,7 +521,7 @@ class Array { } } } while (next_index(&index)); - return tensorflow::str_util::Join(pieces, ""); + return absl::StrJoin(pieces, ""); } private: diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index a17e81f44832f272fd93dce9f854042b4a84fde4..782c966b4c57672d137569a318fb20ace14d493b 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -24,12 +24,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -101,7 +100,7 @@ class Array2D : public Array { template std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64 n1, int64 n2) { - auto array = MakeUnique>(n1, n2); + auto array = absl::make_unique>(n1, n2); int64 count = n1 * n2; NativeT step = static_cast((count > 1) ? (to - from) / (count - 1) : 0); diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index a75fffc605aa0df3e1e2eeb6d3129718cbbba0e4..e23d317baf9aca7b3705a93d6be952fb9a17762b 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -26,13 +26,11 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc index 927733ea1eab43feff643c35535cc6d9ea59ba5a..918872a7a03a022c72d22dfb8f0da9e9d3820e41 100644 --- a/tensorflow/compiler/xla/array4d_test.cc +++ b/tensorflow/compiler/xla/array4d_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { @@ -27,8 +27,7 @@ namespace { // Given an Array4D and a 4-tuple index, computes the linear index into the // array idx represents. template -int64 Array4DLinearIndex(const Array4D& arr, - tensorflow::gtl::ArraySlice idx) { +int64 Array4DLinearIndex(const Array4D& arr, absl::Span idx) { EXPECT_EQ(4, idx.size()); return (idx[3] + idx[2] * arr.n4() + idx[1] * arr.n3() * arr.n4() + idx[0] * arr.n2() * arr.n3() * arr.n4()); @@ -51,9 +50,8 @@ TEST(Array4dTest, FillCtor) { EXPECT_EQ(fullof7.n3(), 4); EXPECT_EQ(fullof7.n4(), 5); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 7); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 7); }); } TEST(Array4dTest, ContainerCtor) { @@ -69,7 +67,7 @@ TEST(Array4dTest, ContainerCtor) { EXPECT_EQ(arr.n3(), 4); EXPECT_EQ(arr.n4(), 5); - arr.Each([&arr](tensorflow::gtl::ArraySlice idx, int* cell) { + arr.Each([&arr](absl::Span idx, int* cell) { EXPECT_EQ(*cell, Array4DLinearIndex(arr, idx)); }); } @@ -129,21 +127,19 @@ TEST(Array3dTest, InitializerListCtorHalf) { TEST(Array4dTest, Fill) { Array4D fullof7(2, 3, 4, 5, 7); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 7); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 7); }); fullof7.Fill(11); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 11); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 11); }); } TEST(Array4dTest, FillWithMultiples) { Array4D arr(2, 3, 4, 5); arr.FillWithMultiples(2.0f); - arr.Each([&arr](tensorflow::gtl::ArraySlice idx, float* cell) { + arr.Each([&arr](absl::Span idx, float* cell) { EXPECT_EQ(*cell, 2.0f * Array4DLinearIndex(arr, idx)); }); } diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc index e8356c9832d34135f5ffb1a5c7a9d6db6db3a051..2d0ac98bd4ee27004295c4189cb190bb2c9739c9 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -163,7 +163,7 @@ TEST(ArrayTest, Each) { arr.FillWithMultiples(1); int64 each_count = 0, each_sum = 0; - arr.Each([&](tensorflow::gtl::ArraySlice idx, int cell) { + arr.Each([&](absl::Span idx, int cell) { int64 lin_idx = idx[0] * 12 + idx[1] * 4 + idx[2]; EXPECT_EQ(lin_idx, cell); each_count++; diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index ad3fcee05b80181369bfdf3cdcdb5452ec9e7e89..f825f67b447514a416f3a49ac8aad9dcf505f5a7 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -45,6 +45,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -71,12 +72,14 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -90,6 +93,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", ], ) @@ -104,7 +110,6 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", @@ -115,8 +120,9 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", "//tensorflow/compiler/xla/service:stream_pool", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", "@llvm//:support", ], ) @@ -130,11 +136,11 @@ cc_library( ":xla_computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:support", ], ) @@ -159,6 +165,7 @@ cc_library( "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -186,6 +193,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", + "@com_google_absl//absl/memory", ], ) @@ -211,6 +219,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index d0ce5e8a6afa262d4cffdfe8431aab570ffd28df..5dde5b432f136c16d4e3795569499ee5de709763 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {} Client::~Client() = default; -StatusOr> Client::Transfer( - const GlobalData& data, const Shape* shape_with_layout) { +StatusOr Client::Transfer(const GlobalData& data, + const Shape* shape_with_layout) { TransferToClientRequest request; *request.mutable_data() = data.handle(); if (shape_with_layout != nullptr) { @@ -89,7 +89,7 @@ StatusOr> Client::TransferToServer( "TransferToServer request"); } - return MakeUnique(stub_, response.data()); + return absl::make_unique(stub_, response.data()); } Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, @@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, return Status::OK(); } -StatusOr> Client::TransferFromOutfeed( +StatusOr Client::TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id, const DeviceHandle* device_handle) { TransferFromOutfeedRequest request; @@ -162,9 +162,8 @@ Status Client::ResetDevice() { return Status::OK(); } -StatusOr> Client::ExecuteAndTransfer( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, +StatusOr Client::ExecuteAndTransfer( + const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { TF_ASSIGN_OR_RETURN( @@ -178,8 +177,8 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } -StatusOr> Client::ComputeConstant( - const XlaComputation& computation, const Layout* output_layout) const { +StatusOr Client::ComputeConstant(const XlaComputation& computation, + const Layout* output_layout) const { ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { @@ -212,8 +211,7 @@ StatusOr Client::LoadSnapshot(const HloSnapshot& module) { } StatusOr> Client::Execute( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { ExecuteGraphRequest request; @@ -248,11 +246,11 @@ StatusOr> Client::Execute( } } - return MakeUnique(stub_, response.output()); + return absl::make_unique(stub_, response.output()); } StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { ExecuteGraphParallelRequest request; for (const XlaComputationInstance& computation : computations) { @@ -278,7 +276,7 @@ StatusOr>> Client::ExecuteParallel( std::vector> outputs; for (size_t i = 0; i < computations.size(); ++i) { outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); + absl::make_unique(stub_, response.responses(i).output())); if (computations[i].execution_profile != nullptr) { *computations[i].execution_profile = response.responses(i).profile(); } @@ -340,7 +338,7 @@ StatusOr>> Client::DeconstructTuple( std::vector> handles; for (auto& handle : response.element_handles()) { - handles.push_back(MakeUnique(stub_, handle)); + handles.push_back(absl::make_unique(stub_, handle)); } return std::move(handles); } @@ -369,7 +367,7 @@ StatusOr Client::GetComputationStats( StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); - return MakeUnique(result); + return absl::make_unique(result); } StatusOr Client::GetShape(const GlobalData& data) { @@ -400,7 +398,7 @@ StatusOr Client::ExecutionStatsAsString( int64 nanoseconds = profile.compute_time_ns(); int64 cycle_count = profile.compute_cycle_count(); double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( + return absl::StrCat( "[Execution Statistics] flop count: ", computation_stats.flop_count(), ", transcendental count: ", computation_stats.transcendental_count(), ", compute execution time: ", nanoseconds, " nsec", diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index be50cebfcc0e3c19002635dbd280b14048aa0c93..6f4d33c469f1f885cfeef546e3981dc3417ef71f 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -53,7 +53,7 @@ class Client { // will be filled with profile data from the execution. StatusOr> Execute( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); @@ -82,7 +82,7 @@ class Client { // from each computation. // StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Requests device_count device handles available on the target. The returned // device handles are used to specify the devices to execute the computations @@ -96,8 +96,8 @@ class Client { // // If shape_with_layout is not nullptr, it points to a shape whose layout will // be the layout of the returned literal. - StatusOr> Transfer( - const GlobalData& data, const Shape* shape_with_layout = nullptr); + StatusOr Transfer(const GlobalData& data, + const Shape* shape_with_layout = nullptr); // Transfer the given literal to the server. This allocates memory on the // device and copies the literal's contents over. Returns a global data handle @@ -122,7 +122,7 @@ class Client { // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - StatusOr> TransferFromOutfeed( + StatusOr TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); @@ -132,9 +132,9 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); @@ -153,7 +153,7 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - StatusOr> ComputeConstant( + StatusOr ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 803a9e40094391ba47ed27713f4538caf875c4f6..27b7fa7b29206affa9f9c2e4becd9e4ea66484ab 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -94,10 +95,10 @@ ClientLibrary::~ClientLibrary() = default; service_options.set_intra_op_parallelism_threads( options.intra_op_parallelism_threads()); - auto instance = MakeUnique(); + auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); - instance->client = MakeUnique(instance->service.get()); + instance->client = absl::make_unique(instance->service.get()); LocalClient* cl = instance->client.get(); client_library.local_instances_.insert( @@ -134,10 +135,11 @@ ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) { return it->second->client.get(); } - auto instance = MakeUnique(); + auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, CompileOnlyService::NewService(platform)); - instance->client = MakeUnique(instance->service.get()); + instance->client = + absl::make_unique(instance->service.get()); CompileOnlyClient* cl = instance->client.get(); client_library.compile_only_instances_.insert( diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 5c9abad4c3126be5e45e96c770c0679fe8606788..a6c58cb17571b63cd0f45d0d95376a02bc4a72e2 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -15,15 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "absl/memory/memory.h" #include "llvm/ADT/Triple.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata) { std::vector service_instances; @@ -41,7 +41,7 @@ CompileOnlyClient::CompileAheadOfTime( metadata); } -int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { +int64 CompileOnlyClient::PointerSizeForTriple(absl::string_view triple) { llvm::Triple llvm_triple( llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size()))); if (llvm_triple.isArch64Bit()) { diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index a551edeab0943ec5213c5cb035644c02c3cf54d7..9e3ed23734941d98d622c38028cd44d48d3e620a 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -52,12 +52,12 @@ class CompileOnlyClient : public Client { // code. |metadata|, if provided, is populated during compilation. StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata = nullptr); // Returns the size of a pointer in bytes for a given triple. - static int64 PointerSizeForTriple(tensorflow::StringPiece triple); + static int64 PointerSizeForTriple(absl::string_view triple); private: CompileOnlyService* compiler_service_; diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 7dee41f6a05025ec196b78e54015e8e71777031f..0f1745366b7c33e573aff2e66d85431b01488c49 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -59,10 +59,10 @@ string ExecutableBuildOptions::ToString() const { if (generate_hlo_graph_.has_value()) { generate_hlo_graph = generate_hlo_graph_.value(); } - return tensorflow::strings::Printf( + return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " "generate_hlo_graph=%s}", - device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str()); + device_ordinal_, result_layout, generate_hlo_graph); } ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( @@ -71,41 +71,41 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( return *this; } -const tensorflow::gtl::optional& -ExecutableBuildOptions::generate_hlo_graph() const { +const absl::optional& ExecutableBuildOptions::generate_hlo_graph() + const { return generate_hlo_graph_; } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_optimized_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_optimized_hlo_proto_to_ = string(dirpath); return *this; } -const tensorflow::gtl::optional& +const absl::optional& ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { return dump_optimized_hlo_proto_to_; } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_unoptimized_hlo_proto_to_ = string(dirpath); return *this; } -const tensorflow::gtl::optional& +const absl::optional& ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { return dump_unoptimized_hlo_proto_to_; } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_per_pass_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_per_pass_hlo_proto_to_ = string(dirpath); return *this; } -const tensorflow::gtl::optional& +const absl::optional& ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const { return dump_per_pass_hlo_proto_to_; } @@ -115,7 +115,7 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) { return *this; } -tensorflow::gtl::optional ExecutableBuildOptions::hlo_profile() const { +absl::optional ExecutableBuildOptions::hlo_profile() const { return hlo_profile_; } diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 9dc9be4423564fb967b247c2d1df31099cb80237..93334db88bc24f2ffbf3c7a57ee45ef238286739 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -57,37 +57,36 @@ class ExecutableBuildOptions { // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions). ExecutableBuildOptions& set_generate_hlo_graph(string regex); - const tensorflow::gtl::optional& generate_hlo_graph() const; + const absl::optional& generate_hlo_graph() const; // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO // protobuf to (as in DebugOptions). ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( - tensorflow::StringPiece dirpath); - const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + absl::string_view dirpath); + const absl::optional& dump_optimized_hlo_proto_to() const; // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO // protobuf to (as in DebugOptions). ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( - tensorflow::StringPiece dirpath); - const tensorflow::gtl::optional& dump_unoptimized_hlo_proto_to() - const; + absl::string_view dirpath); + const absl::optional& dump_unoptimized_hlo_proto_to() const; // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs // to (as in DebugOptions). ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( - tensorflow::StringPiece dirpath); - const tensorflow::gtl::optional& dump_per_pass_hlo_proto_to() const; + absl::string_view dirpath); + const absl::optional& dump_per_pass_hlo_proto_to() const; // If true, specifies that we should record an HLO profile during execution // and log it after execution (as in DebugOptions). If nullopt the default is // used. ExecutableBuildOptions& set_hlo_profile(bool enabled); - tensorflow::gtl::optional hlo_profile() const; + absl::optional hlo_profile() const; - void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + void add_disabled_hlo_pass(absl::string_view pass_name) { disabled_hlo_passes_.push_back(std::string(pass_name)); } - const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + const absl::Span disabled_hlo_passes() const { return disabled_hlo_passes_; } @@ -96,14 +95,14 @@ class ExecutableBuildOptions { string ToString() const; private: - tensorflow::gtl::optional hlo_profile_; + absl::optional hlo_profile_; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; - tensorflow::gtl::optional generate_hlo_graph_; - tensorflow::gtl::optional dump_optimized_hlo_proto_to_; - tensorflow::gtl::optional dump_unoptimized_hlo_proto_to_; - tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; + absl::optional generate_hlo_graph_; + absl::optional dump_optimized_hlo_proto_to_; + absl::optional dump_unoptimized_hlo_proto_to_; + absl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; std::vector disabled_hlo_passes_; }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 39d5582d19dbb9942ae87e1962fc9fa713bcdd50..a18c94c4e695a6cdcb9dcc60b64b617cecd276d8 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -31,7 +31,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -64,6 +64,17 @@ xla_test( ], ) +cc_library( + name = "conv_grad_size_util", + srcs = ["conv_grad_size_util.cc"], + hdrs = ["conv_grad_size_util.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/core:lib", + ], +) + cc_library( name = "math", srcs = ["math.cc"], @@ -102,7 +113,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -121,6 +132,31 @@ xla_test( ], ) +cc_library( + name = "pooling", + srcs = ["pooling.cc"], + hdrs = ["pooling.h"], + deps = [ + ":arithmetic", + ":constants", + ":conv_grad_size_util", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/container:inlined_vector", + ], +) + +xla_test( + name = "pooling_test", + srcs = ["pooling_test.cc"], + deps = [ + ":pooling", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:inlined_vector", + ], +) + cc_library( name = "prng", srcs = ["prng.cc"], @@ -144,7 +180,7 @@ cc_library( ":numeric", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", ], ) @@ -161,7 +197,7 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -185,5 +221,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 9225b1acd69c214d6f08a45372a8082ed789c18c..e86c10f030f3990d67e5a6638100640f73c82307 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { @@ -39,7 +39,7 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, b = builder->CreateSubBuilder(name); } else { b = builder->CreateSubBuilder( - tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); + absl::StrCat(name, "_", PrimitiveType_Name(type))); } const Shape scalar = ShapeUtil::MakeShape(type, {}); diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index 031d62e4ffef188082303a28866bbc72a154e9b1..1ada7b4a964ccf7ca400b937abbe425bef083468 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -56,7 +56,7 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { std::numeric_limits::epsilon()); default: return builder->ReportError(InvalidArgument( - "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str())); + "Invalid type for Epsilon (%s).", PrimitiveType_Name(type))); } } diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 0c8a9b8cc02ba0c1ebdf6a060d4b99262dceb178..81624614c1e3599dfe116eb61d9e2edcd5230684 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -37,13 +37,13 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { primitive_util::IsComplexType(type))) { return builder->ReportError(InvalidArgument( "Invalid cast from floating point type to %s in ConstantR0WithType.", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } if (std::is_same::value && !primitive_util::IsComplexType(type)) { return builder->ReportError(InvalidArgument( "Invalid cast from complex type to %s in ConstantR0WithType.", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } switch (type) { case F16: @@ -71,7 +71,7 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { default: return builder->ReportError( InvalidArgument("Invalid type for ConstantR0WithType (%s).", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } } diff --git a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4c50a5491803bc62d2de758177f8f5d050f441d --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +StatusOr GetWindowedOutputSize( + int64 input_size, int64 filter_size, int64 dilation_rate, int64 stride, + Padding padding_type) { + if (stride <= 0) { + return tensorflow::errors::InvalidArgument("Stride must be > 0, but got ", + stride); + } + if (dilation_rate < 1) { + return tensorflow::errors::InvalidArgument( + "Dilation rate must be >= 1, but got ", dilation_rate); + } + + int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1; + SpatialDimensionOutputSizeAndPadding dim; + switch (padding_type) { + case Padding::kValid: + dim.output_size = (input_size - effective_filter_size + stride) / stride; + dim.pad_before = dim.pad_after = 0; + break; + case Padding::kSame: + dim.output_size = (input_size + stride - 1) / stride; + const int64 padding_needed = + std::max(int64{0}, (dim.output_size - 1) * stride + + effective_filter_size - input_size); + // For odd values of total padding, add more padding on the "after" side + // of the given dimension. + dim.pad_before = padding_needed / 2; + dim.pad_after = padding_needed - dim.pad_before; + break; + } + if (dim.output_size < 0) { + return tensorflow::errors::InvalidArgument( + "Computed output size would be negative: ", dim.output_size, + " [input_size: ", input_size, + ", effective_filter_size: ", effective_filter_size, + ", stride: ", stride, "]"); + } + return dim; +} + +} // namespace + +StatusOr +ConvGradExtractAndVerifyDimension(int64 input_size, int64 filter_size, + int64 output_size, int64 dilation, + int64 stride, Padding padding) { + TF_ASSIGN_OR_RETURN(SpatialDimensionOutputSizeAndPadding output_dim, + GetWindowedOutputSize(input_size, filter_size, dilation, + stride, padding)); + if (output_size != output_dim.output_size) { + return tensorflow::errors::InvalidArgument( + "Size of out_backprop doesn't match computed: ", "actual = ", + output_size, ", computed = ", output_dim.output_size, + " input: ", input_size, " filter: ", filter_size, + " output: ", output_size, " stride: ", stride, " dilation: ", dilation); + } + + SpatialDimensionOutputSizeAndPadding dim; + int64 effective_filter_size = (filter_size - 1) * dilation + 1; + dim.output_size = (output_dim.output_size - 1) * stride + 1; + const auto padded_out_size = input_size + effective_filter_size - 1; + dim.pad_before = effective_filter_size - 1 - output_dim.pad_before; + dim.pad_after = padded_out_size - dim.output_size - dim.pad_before; + VLOG(2) << "expanded_out = " << dim.output_size + << ", effective_filter_size = " << effective_filter_size + << ", padded_out = " << padded_out_size + << ", pad_before = " << dim.pad_before + << ", pad_after = " << dim.pad_after << ", dilation = " << dilation + << ", strides = " << stride; + return dim; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0ad01728e6e828240b9ac4b948777e5d970d09e0 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ + +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Information about a single spatial dimension for a convolution gradients and +// windowed operations. +struct SpatialDimensionOutputSizeAndPadding { + // Effective size of the operation output (potentially expanded). + int64 output_size; + // Number of padding elements to be added before/after this dimension of + // the input when computing the input gradient. + int64 pad_before; + int64 pad_after; +}; + +// Verifies that the dimensions all match, and computes the size and padding of +// a spatial dimension for convolution gradient operations. +StatusOr +ConvGradExtractAndVerifyDimension(int64 input_size, int64 filter_size, + int64 output_size, int64 dilation, + int64 stride, Padding padding); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 0221de7672c7b7c02b1f8b9c7ff4f92151e567c6..d3d7edb42a38595bbf9fdb36e0dd946ae5df51f9 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -69,8 +69,7 @@ std::array kErfUCoefficient = { // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients) { +XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { XlaOp poly = ScalarLike(x, 0.0); for (float c : coefficients) { poly = poly * x + ScalarLike(x, c); @@ -207,7 +206,11 @@ XlaOp Lgamma(XlaOp input) { XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); - XlaOp reflection = log_pi - Log(Sin(pi * input)) - log_y; + // If z = a + 0j, the analytic continuation of log reduces to taking the + // absolute value of the real part. + // Re(log(z)) = Re(log|z| + arg(z)j) + // = log|a| + XlaOp reflection = log_pi - Log(Abs(Sin(pi * input))) - log_y; XlaOp result = Select(need_to_reflect, reflection, log_y); return result; } diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 13db2325569cf2e25e3ff1200adf4b2544dc2f73..a6cafd42077367bf23ffa1f45eab31c01dc31b16 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -34,8 +34,7 @@ XlaOp Reciprocal(XlaOp operand); // Evaluates a polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients); +XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients); // Computes an approximation of the error function complement (1 - erf(x)). XlaOp Erfc(XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc index 1c91237ae1574f92cda78c9bddc6f4ac1d68f47c..377654220b5df4487e9e194361473d54ff46a54e 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -16,61 +16,13 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { -namespace { - -template -XlaOp MakeIota(XlaBuilder* builder, int64 size) { - std::vector values(size); - for (int64 i = 0; i < size; ++i) { - values[i] = static_cast(i); - } - return ConstantR1(builder, values); -} - -} // namespace - -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { - switch (type) { - case S8: - return MakeIota(builder, size); - case S16: - return MakeIota(builder, size); - case S32: - return MakeIota(builder, size); - case S64: - return MakeIota(builder, size); - case U8: - return MakeIota(builder, size); - case U16: - return MakeIota(builder, size); - case U32: - return MakeIota(builder, size); - case U64: - return MakeIota(builder, size); - case BF16: - return MakeIota(builder, size); - case F16: - return MakeIota(builder, size); - case F32: - return MakeIota(builder, size); - case F64: - return MakeIota(builder, size); - case C64: - return MakeIota(builder, size); - default: - return builder->ReportError( - InvalidArgument("Unimplemented type for Iota: %s.", - PrimitiveType_Name(type).c_str())); - } -} - XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n) { auto a = Iota(builder, type, m); @@ -87,8 +39,8 @@ XlaOp GetMatrixDiagonal(XlaOp x) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); - tensorflow::gtl::ArraySlice major_dims( - AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); auto a = Iota(builder, U32, n); auto b = Iota(builder, U32, m); auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); @@ -114,8 +66,8 @@ XlaOp Triangle(XlaOp x, bool lower) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); - tensorflow::gtl::ArraySlice major_dims( - AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); auto a = Iota(builder, U32, n); auto b = Iota(builder, U32, m); xla::XlaOp indicator; diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc index 8a96ec68d2dca8485215258b1f6731b934e6f2a8..7d6aedd49462bd4f075f90d0b0f85c40f1191aa1 100644 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -30,16 +30,6 @@ class NumericTest : public ClientLibraryTestBase { void TestMatrixDiagonal(); }; -// TODO(b/64798317): Delete this test case once xla::IotaGen is converted to -// xla::Iota. This test is already implemented for xla::IotaGen in -// xla/tests/iota_test.cc. -XLA_TEST_F(NumericTest, Iota) { - XlaBuilder builder(TestName()); - Iota(&builder, S32, 10); - - ComputeAndCompareR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {}); -} - XLA_TEST_F(NumericTest, Triangle) { XlaBuilder builder(TestName()); Array3D input(2, 3, 4); diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..1979c867a4c3be438f8b997c566799fe84b43053 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -0,0 +1,289 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" + +namespace xla { + +namespace { + +// Common computation shared between AvgPool and AvgPoolGrad. Divide each +// element of an image by the count of elements that contributed to that +// element during pooling. +XlaOp AvgPoolDivideByCountWithGeneralPadding( + XlaOp sums, PrimitiveType dtype, absl::Span input_shape, + absl::Span> spatial_padding, + absl::Span ksize, absl::Span stride, + const TensorFormat& data_format) { + // The padding shouldn't be included in the counts. We use another + // ReduceWindow to find the right counts. + const int num_spatial_dims = spatial_padding.size(); + + std::vector input_dim_sizes(num_spatial_dims); + std::vector window_dims(num_spatial_dims); + std::vector window_ksize(num_spatial_dims); + std::vector window_stride(num_spatial_dims); + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_dim_sizes[i] = input_shape[dim]; + window_dims[i] = dim; + window_ksize[i] = ksize[dim]; + window_stride[i] = stride[dim]; + } + + XlaBuilder* b = sums.builder(); + // Build a matrix of all 1s, with the same width/height as the input. + auto ones = Broadcast(One(b, dtype), input_dim_sizes); + PaddingConfig padding_config; + for (int i = 0; i < num_spatial_dims; ++i) { + auto dims = padding_config.add_dimensions(); + dims->set_edge_padding_low(spatial_padding[i].first); + dims->set_edge_padding_high(spatial_padding[i].second); + } + auto zero = Zero(b, dtype); + auto padded_ones = Pad(ones, zero, padding_config); + + // Perform a ReduceWindow with the same window size, strides, and padding + // to count the number of contributions to each result element. + auto counts = + ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b), + window_ksize, window_stride, Padding::kValid); + + return Div(sums, counts, window_dims); +} + +// Sums all elements in the window specified by 'kernel_size' and 'stride'. +XlaOp ComputeSums(XlaOp operand, XlaOp init_value, + absl::Span kernel_size, + absl::Span stride, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value)); + PrimitiveType accumulation_type = init_shape.element_type(); + auto add_computation = CreateScalarAddComputation(accumulation_type, b); + return ReduceWindow(operand, init_value, add_computation, kernel_size, + stride, Padding::kValid); + }); +} + +// Creates a padding configuration out of spatial padding values. +PaddingConfig MakeSpatialPaddingConfig( + absl::Span> spatial_padding, + int num_spatial_dims, absl::Span stride, + const TensorFormat& data_format) { + PaddingConfig padding_config; + for (int i = 0; i < 2 + num_spatial_dims; ++i) { + padding_config.add_dimensions(); + } + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + auto padding_dimension = padding_config.mutable_dimensions(dim); + padding_dimension->set_edge_padding_low(spatial_padding[i].first); + padding_dimension->set_edge_padding_high(spatial_padding[i].second); + } + return padding_config; +} + +XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span input_size, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + PrimitiveType dtype, const TensorFormat& data_format, + bool counts_include_padding) { + if (counts_include_padding) { + // If counts include padding, all windows have the same number of elements + // contributing to each average. Divide by the window size everywhere to get + // the average. + int64 window_size = + std::accumulate(window_dimensions.begin(), window_dimensions.end(), 1, + [](int64 a, int64 b) { return a * b; }); + auto divisor = ConstantR0WithType(pooled.builder(), dtype, window_size); + + return pooled / divisor; + } else { + return AvgPoolDivideByCountWithGeneralPadding(pooled, dtype, input_size, + padding, window_dimensions, + window_strides, data_format); + } +} + +} // namespace + +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto max_computation = CreateScalarMaxComputation(dtype, b); + auto init_value = MinValue(b, dtype); + return ReduceWindow(operand, init_value, max_computation, kernel_size, + stride, padding); + }); +} + +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, + const TensorFormat& data_format, + const bool counts_include_padding) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto init_value = Zero(b, dtype); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + const int num_dims = kernel_size.size(); + const int num_spatial_dims = num_dims - 2; + auto padding_config = MakeSpatialPaddingConfig(padding, num_spatial_dims, + stride, data_format); + auto padded_operand = Pad(operand, Zero(b, dtype), padding_config); + auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride, + data_format); + return AvgPoolDivideByCount(pooled, input_size, kernel_size, stride, + padding, dtype, data_format, + counts_include_padding); + }); +} + +std::vector> MakeSpatialPadding( + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, + const TensorFormat& data_format) { + const int num_spatial_dims = kernel_size.size() - 2; + std::vector input_spatial_dimensions; + std::vector kernel_size_spatial_dimensions; + std::vector stride_spatial_dimensions; + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_spatial_dimensions.push_back(input_size[dim]); + kernel_size_spatial_dimensions.push_back(kernel_size[dim]); + stride_spatial_dimensions.push_back(stride[dim]); + } + return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions, + stride_spatial_dimensions, padding); +} + +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, + const bool counts_include_padding) { + XlaBuilder* b = out_backprop.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + const int num_dims = kernel_size.size(); + + if (gradients_size.size() != num_dims) { + return tensorflow::errors::InvalidArgument("gradients must be ", num_dims, + "-dimensional"); + } + + TF_ASSIGN_OR_RETURN(Shape out_backprop_xla_shape, + b->GetShape(out_backprop)); + if (out_backprop_xla_shape.dimensions().size() != num_dims) { + return tensorflow::errors::InvalidArgument("out_backprop must be ", + num_dims, "-dimensional"); + } + + // We can think of average-pooling as: + // * a convolution with a kernel consisting entirely of 1s, where the + // input feature and output feature are equal, and 0s everywhere else. + // * followed by dividing by the counts. + // + // This then gives us an algorithm to build the gradient: + // * divide out_backprop by the counts, followed by + // * Conv2DBackpropInput specialized for that kernel, which simplifies to + // a Pad and a ReduceWindow. + // + // For an explanation of backpropagation for convolution, see the comments + // in third_party/tensorflow/core/kernels/conv_grad_ops.h + + // TF filter shape is [ H, W, ..., inC, outC ] + + // The input gradients are computed by a convolution of the output gradients + // and the filter, with some appropriate padding. See the comment at the top + // of conv_grad_ops.h for details. + PrimitiveType dtype = out_backprop_xla_shape.element_type(); + auto out_backprop_div = AvgPoolDivideByCount( + out_backprop, gradients_size, kernel_size, stride, spatial_padding, + dtype, data_format, counts_include_padding); + + // Pad the gradients in the spatial dimensions. We use the same padding + // as Conv2DBackpropInput. + PaddingConfig padding_config = MakeNoPaddingConfig(num_dims); + std::vector padded_gradients_size(gradients_size.begin(), + gradients_size.end()); + // First, pad the output gradients the same way as the input. The additional + // padding will be removed as a last step before returning the input + // gradients. + const int num_spatial_dims = num_dims - 2; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + padded_gradients_size[dim] += + (spatial_padding[i].first + spatial_padding[i].second); + } + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + TF_ASSIGN_OR_RETURN( + SpatialDimensionOutputSizeAndPadding conv_backprop_spatial_dim, + ConvGradExtractAndVerifyDimension( + /*input_size=*/padded_gradients_size[dim], + /*filter_size=*/kernel_size[dim], + /*output_size=*/out_backprop_xla_shape.dimensions(dim), + /*dilation=*/1, + /*stride=*/stride[dim], /*padding=*/Padding::kValid)); + auto* padding = padding_config.mutable_dimensions(dim); + padding->set_edge_padding_low(conv_backprop_spatial_dim.pad_before); + padding->set_edge_padding_high(conv_backprop_spatial_dim.pad_after); + padding->set_interior_padding(stride[dim] - 1); + } + + auto zero = Zero(b, dtype); + auto padded_gradients = Pad(out_backprop_div, zero, padding_config); + + // in_backprop = padded_gradients ones + std::vector ones(num_dims, 1LL); + auto in_backprop = + ReduceWindow(padded_gradients, Zero(b, dtype), + CreateScalarAddComputation(dtype, b), kernel_size, + /*window_strides=*/ones, Padding::kValid); + // The input padding doesn't contribute to the gradient, remove it. + std::vector> neg_spatial_padding; + neg_spatial_padding.reserve(spatial_padding.size()); + for (const std::pair& spatial_padding_dim : spatial_padding) { + neg_spatial_padding.emplace_back(-spatial_padding_dim.first, + -spatial_padding_dim.second); + } + auto remove_padding_config = MakeSpatialPaddingConfig( + neg_spatial_padding, num_spatial_dims, stride, data_format); + return Pad(in_backprop, zero, remove_padding_config); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..5c0054857d072dc7f36e259a29b9b24fd70796ac --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ + +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace xla { + +// Tensor format for reduce window operations. +class TensorFormat { + public: + TensorFormat(int batch_dimension, int feature_dimension, + absl::Span spatial_dimensions) + : batch_dimension_(batch_dimension), + feature_dimension_(feature_dimension), + spatial_dimensions_(spatial_dimensions.begin(), + spatial_dimensions.end()) {} + + int batch_dimension() const { return batch_dimension_; } + + int feature_dimension() const { return feature_dimension_; } + + int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; } + + int num_spatial_dims() const { return spatial_dimensions_.size(); } + + private: + // The number of the dimension that represents the batch. + int batch_dimension_; + // The number of the dimension that represents the features. + int feature_dimension_; + // The dimension numbers for the spatial dimensions. + absl::InlinedVector spatial_dimensions_; +}; + +// Computes the max pool of 'operand'. +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool of 'operand'. +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, + const TensorFormat& data_format, + const bool counts_include_padding); + +// Returns the list of low and high padding elements in each spatial dimension +// for the given 'padding' specification. +std::vector> MakeSpatialPadding( + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool gradient. +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, + const bool counts_include_padding); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30adb9b1ad7fa03b40ce3802a2172680b60a9ad7 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc @@ -0,0 +1,290 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +TensorFormat MakeNCHWFormat(int num_spatial_dims) { + absl::InlinedVector spatial_dimensions; + for (int i = 0; i < num_spatial_dims; ++i) { + spatial_dimensions.push_back(i + 2); + } + return TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1, + /*spatial_dimensions=*/spatial_dimensions); +} + +std::vector> MakeGeneralPadding( + XlaOp input, absl::Span kernel_size, + absl::Span stride, Padding padding, + const xla::TensorFormat& data_format) { + XlaBuilder* b = input.builder(); + Shape operand_shape = b->GetShape(input).ValueOrDie(); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + return MakeSpatialPadding(input_size, kernel_size, stride, padding, + data_format); +} + +// Add singleton batch and feature dimensions to spatial dimensions, according +// to 'data_format' specification. +std::vector ExpandWithBatchAndFeatureDimensions( + absl::Span spatial_dim_sizes, + const xla::TensorFormat& data_format) { + const int num_spatial_dims = spatial_dim_sizes.size(); + std::vector tensor_sizes(num_spatial_dims + 2, 1); + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + tensor_sizes[dim] = spatial_dim_sizes[i]; + } + return tensor_sizes; +} + +class PoolingTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(PoolingTest, MaxPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kValid, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 5}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}}, + {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kValid, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/true); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, + {{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = kernel_size; + AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, + AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPool(input, kernel_size, stride, {{2, 1}, {1, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DGradNoPadding) { + XlaBuilder builder(TestName()); + for (bool counts_include_padding : {false, true}) { + XlaOp out_backprop = ConstantR4FromArray4D(&builder, {{{{1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, + {{0, 0}, {0, 0}}, MakeNCHWFormat(2), + /*counts_include_padding=*/counts_include_padding); + // Without padding, counts_include_padding makes no difference. + ComputeAndCompareR4( + &builder, {{{{0.25, 0.25, 0.}, {0.25, 0.25, 0.}, {0., 0., 0.}}}}, {}, + error_spec_); + } +} + +XLA_TEST_F(PoolingTest, AvgPool2DGradNoPaddingWithStride) { + XlaBuilder builder(TestName()); + for (bool counts_include_padding : {false, true}) { + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1.}, {1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, + {{0, 0}, {0, 0}}, MakeNCHWFormat(2), + /*counts_include_padding=*/counts_include_padding); + // Without padding, counts_include_padding makes no difference. + ComputeAndCompareR4( + &builder, {{{{0.25, 0.5, 0.25}, {0.5, 1., 0.5}, {0.25, 0.5, 0.25}}}}, + {}, error_spec_); + } +} + +XLA_TEST_F(PoolingTest, AvgPool2DGradWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1.}, {1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}}, + MakeNCHWFormat(2), + /*counts_include_padding=*/true); + ComputeAndCompareR4( + &builder, + {{{{0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountNotIncludePadding) { + XlaBuilder builder(TestName()); + + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1.}, {1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}}, + MakeNCHWFormat(2), false); + ComputeAndCompareR4( + &builder, {{{{1., 0.5, 0.5}, {0.5, 0.25, 0.25}, {0.5, 0.25, 0.25}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountWithStride) { + XlaBuilder builder(TestName()); + + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1., 1., 1.}, + {1., 1., 1., 1.}, + {1., 1., 1., 1.}, + {1., 1., 1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}}, + MakeNCHWFormat(2), true); + ComputeAndCompareR4(&builder, + {{{{1., 1., 1.}, {1., 1., 1.}, {1., 1., 1.}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, + AvgPool2DGradWithPaddingCountWithStrideNotIncludePadding) { + XlaBuilder builder(TestName()); + + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1., 1., 1.}, + {1., 1., 1., 1.}, + {1., 1., 1., 1.}, + {1., 1., 1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}}, + MakeNCHWFormat(2), false); + ComputeAndCompareR4( + &builder, {{{{2.25, 1.5, 2.25}, {1.5, 1., 1.5}, {2.25, 1.5, 2.25}}}}, {}, + error_spec_); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.h b/tensorflow/compiler/xla/client/lib/sorting.h index 404b4783c3878ca0fab811fa8c3d02686af44316..b9dfafdd6f957ae050e0f5dbd076d5288235b490 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.h +++ b/tensorflow/compiler/xla/client/lib/sorting.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index b6eee762a5f002e00fd6118d91f25343e22f13d3..fef98c9923096e21a755c6d730de2c7c10852b2d 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 081fec7ad92958aa285e4be41394d7b1876e0815..ff0ec76a7f9b62fce0f14beae688cb0dd74847a1 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -61,8 +61,7 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { - XlaBuilder b( - tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); + XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); XlaComputation computation = b.Build().ConsumeValueOrDie(); @@ -77,7 +76,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { if (DataSizeOfShape(shape) < (1LL << 20)) { - StatusOr> literal_status = MakeFakeLiteral(shape); + StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via // an on-device computation. @@ -85,7 +84,7 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, tensorflow::error::UNIMPLEMENTED); return MakeFakeDataViaDeviceOrDie(shape, client); } - return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie(); + return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. @@ -98,13 +97,11 @@ std::vector> MakeFakeArgumentsOrDie( << "Computation should have progran shape."; auto program_shape = computation.proto().program_shape(); - // Create and run a program which produces a tuple with one element per - // parameter, then return the tuple's constituent buffers. - std::vector param_shapes(program_shape.parameters().begin(), - program_shape.parameters().end()); - auto fake_input_tuple = - MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client); - return client->DeconstructTuple(*fake_input_tuple).ValueOrDie(); + std::vector> results; + for (const Shape& shape : program_shape.parameters()) { + results.push_back(MakeFakeDataOrDie(shape, client)); + } + return results; } } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 8a6c5fb9a750cd74d47a66269843ec252ffbbbd4..f96b6c9c261a9686fb647e3da0dcc933cd1f70df 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" @@ -51,7 +51,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, } Status LocalExecutable::ValidateExecutionOptions( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& computation_layout = executable_->module_config().entry_computation_layout(); @@ -59,7 +59,7 @@ Status LocalExecutable::ValidateExecutionOptions( // Check argument number, shapes, and layouts. if (arguments.size() != computation_layout.parameter_count()) { return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", + "invalid number of arguments for computation: expected %d, got %u", computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { @@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) - .c_str(), - ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); + ShapeUtil::HumanString( + computation_layout.parameter_layout(i).shape()), + ShapeUtil::HumanString(arguments[i]->on_host_shape())); } } @@ -88,8 +88,7 @@ Status LocalExecutable::ValidateExecutionOptions( if (stream_platform != backend_->platform()) { return InvalidArgument( "stream is for platform %s, but service targets platform %s", - stream_platform->Name().c_str(), - backend_->platform()->Name().c_str()); + stream_platform->Name(), backend_->platform()->Name()); } // Cannot specify device_ordinal with a stream. The stream determines these @@ -120,10 +119,10 @@ Status LocalExecutable::ValidateExecutionOptions( return InvalidArgument( "executable is built for device %s of type \"%s\"; cannot run it on " "device %s of type \"%s\"", - backend_->device_name(build_device_ordinal()).c_str(), - build_executor->GetDeviceDescription().name().c_str(), - backend_->device_name(run_device_ordinal).c_str(), - run_executor->GetDeviceDescription().name().c_str()); + backend_->device_name(build_device_ordinal()), + build_executor->GetDeviceDescription().name(), + backend_->device_name(run_device_ordinal), + run_executor->GetDeviceDescription().name()); } if (!run_options.allocator()) { @@ -133,15 +132,15 @@ Status LocalExecutable::ValidateExecutionOptions( if (run_options.allocator()->platform() != backend.platform()) { return InvalidArgument( "allocator platform (%s) does not match service platform (%s)", - run_options.allocator()->platform()->Name().c_str(), - backend.platform()->Name().c_str()); + run_options.allocator()->platform()->Name(), + backend.platform()->Name()); } return Status::OK(); } StatusOr LocalExecutable::Run( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, ExecutableRunOptions run_options) { TF_RETURN_IF_ERROR( ValidateExecutionOptions(arguments, run_options, *backend_)); @@ -178,7 +177,7 @@ StatusOr LocalExecutable::Run( StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, - const tensorflow::gtl::ArraySlice arguments) { + const absl::Span arguments) { executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); @@ -192,13 +191,12 @@ StatusOr LocalExecutable::ExecuteAndDump( } Status LocalExecutable::RecordArguments( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*argument)); - *hlo_snapshot->add_arguments() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument)); + *hlo_snapshot->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -206,13 +204,12 @@ Status LocalExecutable::RecordArguments( Status LocalExecutable::RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_result(); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*result)); - *hlo_snapshot->mutable_result() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result)); + *hlo_snapshot->mutable_result() = literal.ToProto(); return Status::OK(); } -StatusOr> LocalExecutable::LiteralFromShapedBuffer( +StatusOr LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, backend_->BorrowStream(shaped_buffer.device_ordinal())); @@ -246,7 +243,7 @@ Backend* LocalClient::mutable_backend() { StatusOr> LocalClient::Compile( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& options) { ExecutableBuildOptions updated_options = options; if (options.device_ordinal() == -1) { @@ -257,9 +254,9 @@ StatusOr> LocalClient::Compile( TF_ASSIGN_OR_RETURN(std::unique_ptr executable, local_service_->CompileExecutable( computation, argument_layouts, updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); + return absl::WrapUnique(new LocalExecutable(std::move(executable), + local_service_->mutable_backend(), + updated_options)); } StatusOr LocalClient::LiteralToShapedBuffer( @@ -278,7 +275,7 @@ StatusOr LocalClient::LiteralToShapedBuffer( return std::move(scoped_buffer); } -StatusOr> LocalClient::ShapedBufferToLiteral( +StatusOr LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( shaped_buffer.device_ordinal())); @@ -299,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal, literal); } -StatusOr> LocalClient::TransferFromOutfeedLocal( - const Shape& shape, int device_ordinal) { +StatusOr LocalClient::TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - auto literal = MakeUnique(); + auto literal = Literal::CreateFromShape(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, shape, literal.get())); + executor, shape, &literal)); return std::move(literal); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index ae23809261757c637ab4aec036750c371ac60cdc..feb2f8ec9dab5bf13afdc866d10ccbe74f8edcb9 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -40,7 +40,7 @@ class LocalExecutable { // Run the compiled computation with the given arguments and options and // return the result. StatusOr Run( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, ExecutableRunOptions run_options); // Return the options used to build the executable. @@ -63,7 +63,7 @@ class LocalExecutable { // The given ExecutableRunOptions override any values from legacy_flags // (TF_XLA_FLAGS environment variable). Status ValidateExecutionOptions( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to @@ -73,20 +73,18 @@ class LocalExecutable { // (TF_XLA_FLAGS environment variable). StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, - const tensorflow::gtl::ArraySlice arguments); + const absl::Span arguments); // Records the arguments used to invoke the computation in a SessionModule // proto. - Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - HloSnapshot* hlo_snapshot); + Status RecordArguments(const absl::Span arguments, + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. - StatusOr> LiteralFromShapedBuffer( - const ShapedBuffer& shaped_buffer); + StatusOr LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); // The ordinal of the device which this executable was compiled for. The // executable can run on all equivalent devices (as determined by @@ -120,7 +118,7 @@ class LocalClient : public Client { // (TF_XLA_FLAGS environment variable). StatusOr> Compile( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& options); // Copy the literal data to the device with the given ordinal and return as a @@ -133,8 +131,7 @@ class LocalClient : public Client { // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. - StatusOr> ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid // as long as the handle is valid. @@ -152,8 +149,8 @@ class LocalClient : public Client { // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with // Client::TransferFromOutfeed. - StatusOr> TransferFromOutfeedLocal( - const Shape& shape, int device_ordinal); + StatusOr TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal); // Returns the device ordinal that corresponds to the given replica number. // diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 6a9cf466ac0a43ce214ef0e6aae9e6295f137b0f..992b13139c480900e7b983825be61ce88f14e11b 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -23,16 +23,15 @@ limitations under the License. namespace xla { -Status ValidatePaddingValues( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides) { +Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides) { bool ok = input_dimensions.size() == window_dimensions.size() && input_dimensions.size() == window_strides.size(); if (!ok) { return InvalidArgument( - "Want input dimensions size %zu = window dimensions size %zu = window " - "strides size %zu", + "Want input dimensions size %u = window dimensions size %u = window " + "strides size %u", input_dimensions.size(), window_dimensions.size(), window_strides.size()); } @@ -40,9 +39,9 @@ Status ValidatePaddingValues( } std::vector> MakePadding( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions, window_strides)); std::vector> low_high_padding; diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h index e23b0b3a90a091bf80973525810793c3eda4a036..5c009bd49e48b158550a32e64b0d63e2840dd1a9 100644 --- a/tensorflow/compiler/xla/client/padding.h +++ b/tensorflow/compiler/xla/client/padding.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -41,10 +41,9 @@ enum class Padding { // Validates that the slices are acceptable for determining padding -- this can // be used to check the preconditions of MakePadding below to produce an error // message that can be returned to the user. -Status ValidatePaddingValues( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides); +Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides); // Returns the padding needed for the base area, given the base area dimensions, // window dimensions, strides, and the type of padding. @@ -58,9 +57,9 @@ Status ValidatePaddingValues( // window_dimensions, and strides must match, which is equal to the number // of elements in the result vector. std::vector> MakePadding( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); } // namespace xla diff --git a/tensorflow/compiler/xla/client/sharding_builder.h b/tensorflow/compiler/xla/client/sharding_builder.h index 34763e54d946690289ff42a7712b980168933eee..59df3a8762c755848982bc8e2590de968ed2adb6 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.h +++ b/tensorflow/compiler/xla/client/sharding_builder.h @@ -56,4 +56,4 @@ OpSharding Tuple(const ShapeTree& shardings); } // namespace sharding_builder } // namespace xla -#endif +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_SHARDING_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 1cb61f77fb65102efb3b1dd9d77b8bdcbe8d9125..95ff6432a591f87845729b180397e33a85e5e9a5 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -21,19 +21,24 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" namespace xla { -using tensorflow::strings::StrCat; +using absl::StrCat; namespace { @@ -45,21 +50,6 @@ int64 GetUniqueId() { return id; } -// Returns true if an instruction with the given opcode can be the root of the -// computation. -bool CanBeRoot(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAfterAll: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kOutfeed: - case HloOpcode::kTrace: - return false; - default: - return true; - } -} - } // namespace XlaOp operator-(const XlaOp& x) { return Neg(x); } @@ -82,7 +72,7 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) { if (!ShapeUtil::ElementIsIntegral(shape)) { return InvalidArgument( "Argument to >> operator does not have an integral type (%s).", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } if (ShapeUtil::ElementIsSigned(shape)) { return ShiftRightArithmetic(x, y); @@ -100,7 +90,7 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { } StatusOr> XlaBuilder::GetOperandShapes( - tensorflow::gtl::ArraySlice operands) const { + absl::Span operands) const { std::vector operand_shapes; for (const XlaOp& operand : operands) { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); @@ -142,28 +132,14 @@ XlaOp XlaBuilder::ReportErrorOrReturn( return ReportErrorOrReturn(op_creator()); } -StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { +StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { TF_RETURN_IF_ERROR(first_error_); - - TF_RET_CHECK(root_id != nullptr); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto, + LookUpInstructionByHandle(root_id)); ProgramShape program_shape; - // Not all instructions can be roots. Walk backwards from the last added - // instruction until a valid root is found. - int64 index = instructions_.size() - 1; - for (; index >= 0; index--) { - TF_ASSIGN_OR_RETURN(HloOpcode opcode, - StringToHloOpcode(instructions_[index].opcode())); - if (CanBeRoot(opcode)) { - break; - } - } - if (index < 0) { - return FailedPrecondition("no root instruction was found"); - } - *root_id = instructions_[index].id(); - *program_shape.mutable_result() = instructions_[index].shape(); + *program_shape.mutable_result() = root_proto->shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -188,8 +164,15 @@ StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { } StatusOr XlaBuilder::GetProgramShape() const { - int64 root; - return GetProgramShape(&root); + TF_RET_CHECK(!instructions_.empty()); + return GetProgramShape(instructions_.back().id()); +} + +StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return GetProgramShape(root.handle()); } void XlaBuilder::IsConstantVisitor(const int64 op_handle, @@ -199,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, return; } - CHECK(op_handle < instructions_.size() && op_handle >= 0); - - const HloInstructionProto& instr = instructions_[op_handle]; + const HloInstructionProto& instr = + *(LookUpInstructionByHandle(op_handle).ValueOrDie()); const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie(); switch (opcode) { default: @@ -217,7 +199,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // TODO(b/33009255): Implmement constant folding for cross replica sum. case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - case HloOpcode::kHostCompute: case HloOpcode::kCall: // TODO(b/32495713): We aren't checking the to_apply computation itself, // so we conservatively say that computations containing the Call op @@ -244,8 +225,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() { auto build_status = Build(); if (!build_status.ok()) { parent_builder_->ReportError( - AddStatus(build_status.status(), - tensorflow::strings::StrCat("error from: ", name_))); + AddStatus(build_status.status(), absl::StrCat("error from: ", name_))); return {}; } return build_status.ConsumeValueOrDie(); @@ -257,17 +237,29 @@ StatusOr XlaBuilder::Build() { first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } + return Build(instructions_.back().id()); +} + +StatusOr XlaBuilder::Build(XlaOp root) { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return Build(root.handle()); +} + +StatusOr XlaBuilder::Build(int64 root_id) { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. - { - int64 root_id; - TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), - GetProgramShape(&root_id)); - entry.set_root_id(root_id); - } + TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id)); + entry.set_root_id(root_id); for (auto& instruction : instructions_) { // Ensures that the instruction names are unique among the whole graph. @@ -291,6 +283,7 @@ StatusOr XlaBuilder::Build() { // Clear data held by this builder. this->instructions_.clear(); + this->handle_to_index_.clear(); this->embedded_.clear(); this->parameter_numbers_.clear(); @@ -299,7 +292,7 @@ StatusOr XlaBuilder::Build() { StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; @@ -360,9 +353,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { }); } -XlaOp XlaBuilder::BinaryOp( - HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -456,12 +448,12 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } @@ -474,14 +466,27 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { }); } +XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + instr.add_dimensions(iota_dimension); + return AddInstruction(std::move(instr), HloOpcode::kIota); + }); +} + +XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) { + return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0); +} + XlaOp XlaBuilder::Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( @@ -500,7 +505,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { - return InvalidArgument("parameter %lld already registered", + return InvalidArgument("parameter %d already registered", parameter_number); } instr.set_parameter_number(parameter_number); @@ -510,8 +515,8 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, }); } -XlaOp XlaBuilder::Broadcast( - const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { +XlaOp XlaBuilder::Broadcast(const XlaOp& operand, + absl::Span broadcast_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -536,7 +541,7 @@ XlaOp XlaBuilder::Broadcast( XlaOp XlaBuilder::BroadcastInDim( const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions) { + const absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { return InDimBroadcast(shape, operand, broadcast_dimensions); }); @@ -551,9 +556,9 @@ StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { } XlaOp XlaBuilder::Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -588,7 +593,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, } XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -626,15 +631,15 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, }); } -XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, +XlaOp XlaBuilder::ConcatInDim(absl::Span operands, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); @@ -666,8 +671,8 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, } XlaOp XlaBuilder::Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { + absl::Span dimensions, + absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, @@ -681,7 +686,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, } XlaOp XlaBuilder::Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes) { + absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); @@ -691,7 +696,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, } XlaOp XlaBuilder::Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus @@ -701,8 +706,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, // Out-of-order collapse is not supported. // Checks that the collapsed dimensions are in order and consecutive. - for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dimensions.size(); ++i) { + for (absl::Span::size_type i = 1; i < dimensions.size(); ++i) { if (dimensions[i] - 1 != dimensions[i - 1]) { return InvalidArgument( "Collapsed dimensions are not in consecutive order."); @@ -714,8 +718,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand)); VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dimensions, ","); + VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { @@ -726,8 +729,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } } - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; + VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]"; return Reshape(operand, new_sizes); }); @@ -737,7 +739,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); - *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto(); + *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); } @@ -755,13 +757,13 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, }); } -XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { +XlaOp XlaBuilder::Tuple(absl::Span elements) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); @@ -776,7 +778,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { if (!ShapeUtil::IsTuple(tuple_shape)) { return InvalidArgument( "Operand to GetTupleElement() is not a tuple; got %s", - ShapeUtil::HumanString(tuple_shape).c_str()); + ShapeUtil::HumanString(tuple_shape)); } *instr.mutable_shape() = ShapeUtil::GetTupleElementShape(tuple_shape, index); @@ -789,36 +791,37 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { } XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); } -XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { +XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -826,12 +829,13 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { dimension_numbers.add_lhs_contracting_dimensions( lhs_shape.dimensions_size() == 1 ? 0 : 1); dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers); + return DotGeneral(lhs, rhs, dimension_numbers, precision_config); }); } XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers) { + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -840,6 +844,9 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); *instr.mutable_dot_dimension_numbers() = dimension_numbers; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; + } return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); }); } @@ -851,16 +858,14 @@ Status XlaBuilder::VerifyConvolution( return InvalidArgument( "Convolution arguments must have same number of " "dimensions. Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str()); + ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } int num_dims = ShapeUtil::Rank(lhs_shape); if (num_dims < 2) { return InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. " "Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str()); + ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } int num_spatial_dims = num_dims - 2; @@ -874,7 +879,7 @@ Status XlaBuilder::VerifyConvolution( } for (int i = 0; i < numbers.size(); ++i) { if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { - return InvalidArgument("Convolution %s[%d] is out of bounds: %lld", + return InvalidArgument("Convolution %s[%d] is out of bounds: %d", field_name, i, numbers.Get(i)); } } @@ -892,25 +897,28 @@ Status XlaBuilder::VerifyConvolution( } XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding) { + absl::Span window_strides, Padding padding, + int64 feature_group_count, + const PrecisionConfig* precision_config) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); + CreateDefaultConvDimensionNumbers(window_strides.size()), + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneral(lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); + CreateDefaultConvDimensionNumbers(window_strides.size()), + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -937,26 +945,27 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( return ConvGeneral(lhs, rhs, window_strides, MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), - dimension_numbers); + dimension_numbers, feature_group_count, + precision_config); }); } XlaOp XlaBuilder::ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers); + dimension_numbers, feature_group_count, + precision_config); } XlaOp XlaBuilder::ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -975,12 +984,17 @@ XlaOp XlaBuilder::ConvGeneralDilated( MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(), - dimension_numbers)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, feature_group_count, + instr.window(), dimension_numbers)); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + instr.set_feature_group_count(feature_group_count); + + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; + } return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs}); @@ -988,22 +1002,21 @@ XlaOp XlaBuilder::ConvGeneralDilated( } StatusOr XlaBuilder::MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation) const { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation) const { const auto verify_size = [&](const size_t x, const char* x_name) { if (x == 0 || x == window_dimensions.size()) { return Status::OK(); } else { return InvalidArgument( - "%s", tensorflow::strings::StrCat( + "%s", absl::StrCat( "Window has different number of window dimensions than of ", x_name, "\nNumber of window dimensions: ", window_dimensions.size(), - "\nNumber of ", x_name, ": ", x, "\n") - .c_str()); + "\nNumber of ", x_name, ": ", x, "\n")); } }; TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); @@ -1043,7 +1056,7 @@ StatusOr XlaBuilder::MakeWindow( } XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { + const absl::Span fft_length) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1084,6 +1097,23 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { "Replicated sharding is not yet supported for infeeds"); } + // Infeed takes a single token operand. Generate the token to pass to the + // infeed. + XlaOp token; + auto make_token = [&]() { + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {}); + }; + if (sharding()) { + // Arbitrarily assign token to device 0. + OpSharding sharding = sharding_builder::AssignDevice(0); + XlaScopedShardingAssignment scoped_sharding(this, sharding); + TF_ASSIGN_OR_RETURN(token, make_token()); + } else { + TF_ASSIGN_OR_RETURN(token, make_token()); + } + // The sharding is set by the client according to the data tuple shape. // However, the shape of the infeed instruction is a tuple containing the // data and a token. For tuple sharding type, the sharding must be changed @@ -1099,11 +1129,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { sharding_builder::AssignDevice(0); XlaScopedShardingAssignment scoped_sharding(this, infeed_instruction_sharding); - TF_ASSIGN_OR_RETURN(infeed, - AddInstruction(std::move(instr), HloOpcode::kInfeed)); + TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr), + HloOpcode::kInfeed, {token})); } else { - TF_ASSIGN_OR_RETURN(infeed, - AddInstruction(std::move(instr), HloOpcode::kInfeed)); + TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr), + HloOpcode::kInfeed, {token})); } // The infeed instruction produces a tuple of the infed data and a token @@ -1162,15 +1192,22 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "Outfeed shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } *instr.mutable_outfeed_shape() = shape_with_layout; instr.set_outfeed_config(outfeed_config); + // Outfeed takes a token as its second operand. Generate the token to pass + // to the outfeed. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); + TF_RETURN_IF_ERROR( - AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}) + AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token}) .status()); // The outfeed instruction produces a token. However, existing users expect @@ -1208,8 +1245,8 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "Outfeed shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } *instr.mutable_outfeed_shape() = shape_with_layout; @@ -1228,7 +1265,7 @@ XlaOp XlaBuilder::CreateToken() { }); } -XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { +XlaOp XlaBuilder::AfterAll(absl::Span tokens) { return ReportErrorOrReturn([&]() -> StatusOr { if (tokens.empty()) { return InvalidArgument("AfterAll requires at least one operand"); @@ -1240,15 +1277,15 @@ XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { } XlaOp XlaBuilder::CustomCall(const string& call_target_name, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - if (tensorflow::str_util::StartsWith(call_target_name, "$")) { + if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " "are reserved for internal use.", - call_target_name.c_str()); + call_target_name); } *instr.mutable_shape() = shape; instr.set_custom_call_target(call_target_name); @@ -1256,21 +1293,8 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, }); } -XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice operands, - const string& channel_name, - int64 cost_estimate_ns, const Shape& shape) { - return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - *instr.mutable_shape() = shape; - instr.set_channel_name(channel_name); - instr.set_cost_estimate_ns(cost_estimate_ns); - return AddInstruction(std::move(instr), HloOpcode::kHostCompute, operands); - }); -} - -XlaOp XlaBuilder::Complex( - const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); } @@ -1279,42 +1303,42 @@ XlaOp XlaBuilder::Conj(const XlaOp& operand) { } XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); } @@ -1322,22 +1346,21 @@ XlaOp XlaBuilder::Not(const XlaOp& operand) { return UnaryOp(HloOpcode::kNot, operand); } -XlaOp XlaBuilder::ShiftLeft( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ShiftRightArithmetic( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ShiftRightLogical( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, broadcast_dimensions); } @@ -1346,9 +1369,8 @@ XlaOp XlaBuilder::Abs(const XlaOp& operand) { return UnaryOp(HloOpcode::kAbs, operand); } -XlaOp XlaBuilder::Atan2( - const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); } @@ -1413,7 +1435,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { } XlaOp XlaBuilder::Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation) { + absl::Span permutation) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1428,7 +1450,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, } XlaOp XlaBuilder::Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1442,7 +1464,7 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional values, +XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional values, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1470,7 +1492,7 @@ XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional values, } XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); } @@ -1508,10 +1530,10 @@ XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, return TernaryOp(HloOpcode::kClamp, min, operand, max); } -XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, +XlaOp XlaBuilder::Map(absl::Span operands, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { + absl::Span dimensions, + absl::Span static_operands) { return ReportErrorOrReturn([&]() -> StatusOr { if (!static_operands.empty()) { return Unimplemented("static_operands is not supported in Map"); @@ -1520,8 +1542,8 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( @@ -1552,7 +1574,7 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, } XlaOp XlaBuilder::RngOp(RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, + absl::Span parameters, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1564,7 +1586,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, if (parameters.size() != 2) { return InvalidArgument( "RNG distribution (%s) expects 2 parameters, but got %ld", - RandomDistribution_Name(distribution).c_str(), parameters.size()); + RandomDistribution_Name(distribution), parameters.size()); } break; default: @@ -1611,27 +1633,27 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, }); } -XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { + absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); - TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape, - GetShape(gather_indices)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), - ShapeInference::InferGatherShape(input_shape, gather_indices_shape, - dimension_numbers, window_bounds)); + ShapeInference::InferGatherShape(input_shape, start_indices_shape, + dimension_numbers, slice_sizes)); *instr.mutable_gather_dimension_numbers() = dimension_numbers; - for (int64 bound : window_bounds) { - instr.add_gather_window_bounds(bound); + for (int64 bound : slice_sizes) { + instr.add_gather_slice_sizes(bound); } return AddInstruction(std::move(instr), HloOpcode::kGather, - {input, gather_indices}); + {input, start_indices}); }); } @@ -1693,22 +1715,39 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, }); } -XlaOp XlaBuilder::Reduce( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { +XlaOp XlaBuilder::Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { + return Reduce(absl::Span({operand}), + absl::Span({init_value}), computation, + dimensions_to_reduce); +} + +XlaOp XlaBuilder::Reduce(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferReduceShape( - {&operand_shape, &init_shape}, dimensions_to_reduce, - called_program_shape)); + std::vector all_operands; + all_operands.insert(all_operands.end(), operands.begin(), operands.end()); + all_operands.insert(all_operands.end(), init_values.begin(), + init_values.end()); + + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, + GetOperandShapes(all_operands)); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferReduceShape( + operand_shape_ptrs, dimensions_to_reduce, called_program_shape)); for (int64 dim : dimensions_to_reduce) { instr.add_dimensions(dim); @@ -1716,8 +1755,7 @@ XlaOp XlaBuilder::Reduce( AddCalledComputation(computation, &instr); - return AddInstruction(std::move(instr), HloOpcode::kReduce, - {operand, init_value}); + return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands); }); } @@ -1731,11 +1769,11 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, }); } -XlaOp XlaBuilder::ReduceWindow( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { +XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1756,9 +1794,9 @@ XlaOp XlaBuilder::ReduceWindow( XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1853,8 +1891,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, } XlaOp XlaBuilder::CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids) { + const XlaOp& operand, absl::Span replica_groups) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); @@ -1862,23 +1899,24 @@ XlaOp XlaBuilder::CrossReplicaSum( b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); TF_ASSIGN_OR_RETURN(auto computation, b->Build()); - return CrossReplicaSum(operand, computation, replica_group_ids, - /*channel_id=*/tensorflow::gtl::nullopt); + return CrossReplicaSum(operand, computation, replica_groups, + /*channel_id=*/absl::nullopt); }); } XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids, - const tensorflow::gtl::optional& channel_id) { + absl::Span replica_groups, + const absl::optional& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferCrossReplicaSumShape({&operand_shape})); - for (int64 replica_group_id : replica_group_ids) { - instr.add_replica_group_ids(replica_group_id); + + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; } if (channel_id.has_value()) { @@ -1892,12 +1930,89 @@ XlaOp XlaBuilder::CrossReplicaSum( }); } -XlaOp XlaBuilder::SelectAndScatter( - const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { +XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + + // The HloInstruction for Alltoall currently only handles the data + // communication: it accepts N already split parts and scatters them to N + // cores, and each core gathers the N received parts into a tuple as the + // output. So here we explicitly split the operand before the hlo alltoall, + // and concat the tuple elements. + // + // First, run shape inference to make sure the shapes are valid. + TF_RETURN_IF_ERROR( + ShapeInference::InferAllToAllShape(operand_shape, split_dimension, + concat_dimension, split_count) + .status()); + + // Split into N parts. + std::vector slices; + slices.reserve(split_count); + const int64 block_size = + operand_shape.dimensions(split_dimension) / split_count; + for (int i = 0; i < split_count; i++) { + slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size, + /*limit_index=*/(i + 1) * block_size, + /*stride=*/1, /*dimno=*/split_dimension)); + } + + // Handle data communication. + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices)); + std::vector slice_shape_ptrs; + absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + TF_ASSIGN_OR_RETURN( + XlaOp alltoall, + AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices)); + + // Concat the N received parts. + std::vector received; + received.reserve(split_count); + for (int i = 0; i < split_count; i++) { + received.push_back(this->GetTupleElement(alltoall, i)); + } + return this->ConcatInDim(received, concat_dimension); + }); +} + +XlaOp XlaBuilder::CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCollectivePermuteShape(operand_shape)); + + for (const auto& pair : source_target_pairs) { + auto* proto_pair = instr.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); + } + + return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute, + {operand}); + }); +} + +XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand, + const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); return SelectAndScatterWithGeneralPadding( @@ -1910,11 +2025,10 @@ XlaOp XlaBuilder::SelectAndScatter( XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -2058,13 +2172,13 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "SendToHost shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } // TODO(b/111544877): Support tuple shapes. if (!ShapeUtil::IsArray(operand_shape)) { return InvalidArgument("SendToHost only supports array shapes, shape: %s", - ShapeUtil::HumanString(operand_shape).c_str()); + ShapeUtil::HumanString(operand_shape)); } if (handle.type() != ChannelHandle::DEVICE_TO_HOST) { @@ -2103,7 +2217,7 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, if (!ShapeUtil::IsArray(shape)) { return InvalidArgument( "RecvFromHost only supports array shapes, shape: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } if (handle.type() != ChannelHandle::HOST_TO_DEVICE) { @@ -2158,16 +2272,11 @@ StatusOr XlaBuilder::BuildConstantSubGraph( "of being evaluated at XLA compile time.\n\n" "Please file a usability bug with the framework being used (e.g. " "TensorFlow).", - op_string.c_str()); + op_string); } TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, LookUpInstruction(root_op)); - TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode())); - if (!CanBeRoot(opcode)) { - return InvalidArgument("the operand with opcode %s cannot be root", - root->opcode().c_str()); - } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. @@ -2177,7 +2286,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( *program_shape->mutable_result() = root->shape(); // We use std::set to keep the instruction ids in ascending order (which is - // also a valid denpendency order). The related ops will be added to the + // also a valid dependency order). The related ops will be added to the // subgraph in the same order. std::set related_ops; tensorflow::gtl::FlatSet related_calls; // Related computations. @@ -2185,14 +2294,16 @@ StatusOr XlaBuilder::BuildConstantSubGraph( worklist.push(root->id()); related_ops.insert(root->id()); while (!worklist.empty()) { - int64 node = worklist.front(); + int64 handle = worklist.front(); worklist.pop(); - for (int64 id : instructions_[node].operand_ids()) { + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, + LookUpInstructionByHandle(handle)); + for (int64 id : instr_proto->operand_ids()) { if (related_ops.insert(id).second) { worklist.push(id); } } - for (int64 called_id : instructions_[node].called_computation_ids()) { + for (int64 called_id : instr_proto->called_computation_ids()) { related_calls.insert(called_id); } } @@ -2200,7 +2311,9 @@ StatusOr XlaBuilder::BuildConstantSubGraph( // Add related ops to the computation. for (int64 id : related_ops) { auto* instr = entry.add_instructions(); - *instr = instructions_[id]; + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src, + LookUpInstructionByHandle(id)); + *instr = *instr_src; // Ensures that the instruction names are unique among the graph. const string& new_name = StrCat(instr->name(), ".", entry.id(), ".", instr->id()); @@ -2226,7 +2339,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( std::unique_ptr XlaBuilder::CreateSubBuilder( const string& computation_name) { - auto sub_builder = MakeUnique(computation_name); + auto sub_builder = absl::make_unique(computation_name); sub_builder->parent_builder_ = this; sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_; return sub_builder; @@ -2271,8 +2384,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the input are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the input are not unique: (%d, %d, %d, " + "%d)", dnum.input_batch_dimension(), dnum.input_feature_dimension(), dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)); } @@ -2282,8 +2395,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.kernel_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the weight are not unique: (%d, %d, %d, " + "%d)", dnum.kernel_output_feature_dimension(), dnum.kernel_input_feature_dimension(), dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1)); @@ -2294,34 +2407,32 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.output_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the output are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the output are not unique: (%d, %d, %d, " + "%d)", dnum.output_batch_dimension(), dnum.output_feature_dimension(), dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1)); } return Status::OK(); } -StatusOr XlaBuilder::AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { +StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode, + absl::Span operands) { TF_RETURN_IF_ERROR(first_error_); - const int64 handle = instructions_.size(); + const int64 handle = GetUniqueId(); instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); if (instr.name().empty()) { - instr.set_name(StrCat(instr.opcode())); + instr.set_name(instr.opcode()); } for (const auto& operand : operands) { if (operand.builder_ == nullptr) { - return InvalidArgument("invalid XlaOp with handle %lld", - operand.handle()); + return InvalidArgument("invalid XlaOp with handle %d", operand.handle()); } if (operand.builder_ != this) { return InvalidArgument("Do not add XlaOp from builder %s to builder %s", - operand.builder_->name().c_str(), - this->name().c_str()); + operand.builder_->name(), this->name()); } instr.add_operand_ids(operand.handle()); } @@ -2331,7 +2442,8 @@ StatusOr XlaBuilder::AddInstruction( *instr.mutable_sharding() = *sharding_; } - instructions_.push_back(instr); + handle_to_index_[handle] = instructions_.size(); + instructions_.push_back(std::move(instr)); XlaOp op(handle, this); return op; @@ -2351,20 +2463,26 @@ StatusOr XlaBuilder::LookUpInstruction( if (op.builder_ == nullptr) { return InvalidArgument( - "invalid XlaOp with handle %lld; the builder of this op is freed", + "invalid XlaOp with handle %d; the builder of this op is freed", op.handle()); } if (op.builder_ != this) { return InvalidArgument( - "XlaOp with handle %lld is built by builder '%s', but is trying to use " + "XlaOp with handle %d is built by builder '%s', but is trying to use " "it in builder '%s'", - op.handle(), op.builder_->name().c_str(), this->name().c_str()); + op.handle(), op.builder_->name(), this->name()); } - if (op.handle() >= instructions_.size() || op.handle() < 0) { - return InvalidArgument("no XlaOp value %lld", op.handle()); + return LookUpInstructionByHandle(op.handle()); +} + +StatusOr XlaBuilder::LookUpInstructionByHandle( + int64 handle) const { + auto it = handle_to_index_.find(handle); + if (it == handle_to_index_.end()) { + return InvalidArgument("No XlaOp with handle %d", handle); } - return &instructions_[op.handle()]; + return &instructions_[it->second]; } // Enqueues a "retrieve parameter value" instruction for a parameter that was @@ -2380,14 +2498,12 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { return builder->ConstantLiteral(literal); } -XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { +XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes) { return operand.builder()->Broadcast(operand, broadcast_sizes); } -XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions) { return operand.builder()->BroadcastInDim(operand, shape, broadcast_dimensions); } @@ -2397,26 +2513,22 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, return operand.builder()->Pad(operand, padding_value, padding_config); } -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { +XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes) { return operand.builder()->Reshape(operand, dimensions, new_sizes); } -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes) { +XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes) { return operand.builder()->Reshape(operand, new_sizes); } -XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { +XlaOp Collapse(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Collapse(operand, dimensions); } -XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { +XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return operand.builder()->Slice(operand, start_indices, limit_indices, strides); } @@ -2428,7 +2540,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, } XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } @@ -2437,8 +2549,7 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } -XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension) { return builder->ConcatInDim(operands, dimension); } @@ -2451,7 +2562,7 @@ XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { return pred.builder()->Select(pred, on_true, on_false); } -XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements) { +XlaOp Tuple(XlaBuilder* builder, absl::Span elements) { return builder->Tuple(elements); } @@ -2460,87 +2571,98 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { } XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); } XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); } XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); } XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); } XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); } XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); } -XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) { - return lhs.builder()->Dot(lhs, rhs); +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfig* precision_config) { + return lhs.builder()->Dot(lhs, rhs, precision_config); } XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers) { - return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config) { + return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, + precision_config); } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return lhs.builder()->Conv(lhs, rhs, window_strides, padding); + absl::Span window_strides, Padding padding, + int64 feature_group_count, const PrecisionConfig* precision_config) { + return lhs.builder()->Conv(lhs, rhs, window_strides, padding, + feature_group_count, precision_config); } -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, - padding); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, + const PrecisionConfig* precision_config) { + return lhs.builder()->ConvWithGeneralPadding( + lhs, rhs, window_strides, padding, feature_group_count, precision_config); } XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides, - padding, dimension_numbers); + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, const PrecisionConfig* precision_config) { + return lhs.builder()->ConvWithGeneralDimensions( + lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, + precision_config); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, - dimension_numbers); + dimension_numbers, feature_group_count, + precision_config); } -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { - return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, - dimension_numbers); +XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfig* precision_config) { + return lhs.builder()->ConvGeneralDilated( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count, precision_config); } XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) { + absl::Span fft_length) { return operand.builder()->Fft(operand, fft_type, fft_length); } @@ -2554,106 +2676,106 @@ void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, } XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { return builder->Call(computation, operands); } XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { + absl::Span operands, const Shape& shape) { return builder->CustomCall(call_target_name, operands, shape); } -XlaOp HostCompute(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, - const Shape& shape) { - return builder->HostCompute(operands, channel_name, cost_estimate_ns, shape); -} - XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return real.builder()->Complex(real, imag, broadcast_dimensions); } XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); } XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); } XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); } XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); } XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); } XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); } XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); } XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->And(lhs, rhs, broadcast_dimensions); } XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); } XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); } XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); } XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); } XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + absl::Span dimensions_to_reduce) { return operand.builder()->Reduce(operand, init_value, computation, dimensions_to_reduce); } +// Reduces several arrays simultaneously among the provided dimensions, given +// "computation" as a reduction operator. +XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { + return builder->Reduce(operands, init_values, computation, + dimensions_to_reduce); +} + XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { return operand.builder()->ReduceAll(operand, init_value, computation); @@ -2661,9 +2783,8 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding) { + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { return operand.builder()->ReduceWindow(operand, init_value, computation, window_dimensions, window_strides, padding); @@ -2672,32 +2793,44 @@ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding) { return operand.builder()->ReduceWindowWithGeneralPadding( operand, init_value, computation, window_dimensions, window_strides, padding); } XlaOp CrossReplicaSum(const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids) { - return operand.builder()->CrossReplicaSum(operand, replica_group_ids); + absl::Span replica_groups) { + return operand.builder()->CrossReplicaSum(operand, replica_groups); } -XlaOp CrossReplicaSum( - const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids, - const tensorflow::gtl::optional& channel_id) { +XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, + absl::Span replica_groups, + const absl::optional& channel_id) { return operand.builder()->CrossReplicaSum(operand, computation, - replica_group_ids, channel_id); + replica_groups, channel_id); +} + +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return operand.builder()->AllToAll(operand, split_dimension, concat_dimension, + split_count, replica_groups); +} + +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs) { + return operand.builder()->CollectivePermute(operand, source_target_pairs); } XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { return operand.builder()->SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter); @@ -2705,11 +2838,10 @@ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { return operand.builder()->SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, padding, source, init_value, scatter); @@ -2718,7 +2850,7 @@ XlaOp SelectAndScatterWithGeneralPadding( XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); } XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return y.builder()->Atan2(y, x, broadcast_dimensions); } @@ -2751,7 +2883,7 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); } @@ -2769,17 +2901,15 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } -XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation) { +XlaOp Transpose(const XlaOp& operand, absl::Span permutation) { return operand.builder()->Transpose(operand, permutation); } -XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { +XlaOp Rev(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Rev(operand, dimensions); } -XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values, - int64 dimension) { +XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension) { return keys.builder()->Sort(keys, std::move(values), dimension); } @@ -2787,10 +2917,9 @@ XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); } -XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, absl::Span dimensions, + absl::Span static_operands) { return builder->Map(operands, computation, dimensions, static_operands); } @@ -2822,11 +2951,11 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, mantissa_bits); } -XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - return input.builder()->Gather(input, gather_indices, dimension_numbers, - window_bounds); + absl::Span slice_sizes) { + return input.builder()->Gather(input, start_indices, dimension_numbers, + slice_sizes); } XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -2880,7 +3009,7 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); } -XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens) { +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens) { return builder->AfterAll(tokens); } @@ -2907,11 +3036,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, grad_output, epsilon, feature_index); } -XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { - HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size}); - return builder->ReportErrorOrReturn( - builder->AddInstruction(std::move(instr), HloOpcode::kIota)); +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { + return builder->Iota(type, size); +} + +XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { + return builder->Iota(shape, iota_dimension); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 8726cc6f93569d94d506bcd4481c00d3427f9008..d0c59fa6f27bc265c0868734ed95a196002fbd2e 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -32,8 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" @@ -154,12 +155,10 @@ class XlaBuilder { // Clears the sharding. Ops will be sharded according to the default placement // policy. - void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + void ClearSharding() { sharding_ = absl::nullopt; } // Returns the OpSharding that will be attached to all instructions. - const tensorflow::gtl::optional& sharding() const { - return sharding_; - } + const absl::optional& sharding() const { return sharding_; } // Sets the builder to a mode where it will die immediately when an error is // encountered, rather than producing it in a deferred fashion when Build() is @@ -195,9 +194,14 @@ class XlaBuilder { // Builds the computation with the requested operations, or returns a non-ok // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. + // computation being returned. The root of the computation will be the last + // added operation. StatusOr Build(); + // Overload of Build which specifies a particular root instruction for the + // computation. + StatusOr Build(XlaOp root); + // Builds the computation with the requested operations, or notes an error in // the parent XlaBuilder and returns an empty computation if building failed. // This function is intended to be used where the returned XlaComputation is @@ -225,9 +229,14 @@ class XlaBuilder { // Returns the shape of the given op. StatusOr GetShape(const XlaOp& op) const; - // Returns the (inferred) result for the current computation's shape. + // Returns the (inferred) result for the current computation's shape. This + // assumes the root instruction is the last added instruction. StatusOr GetProgramShape() const; + // Returns the (inferred) result for the current computation's shape using the + // given operation as the root. + StatusOr GetProgramShape(XlaOp root) const; + // Reports an error to the builder, by // * storing it internally and capturing a backtrace if it's the first error // (this deferred value will be produced on the call to @@ -255,6 +264,9 @@ class XlaBuilder { StatusOr IsConstant(const XlaOp& operand) const; private: + // Build helper which takes the id of the root operation.. + StatusOr Build(int64 root_id); + // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -283,7 +295,7 @@ class XlaBuilder { template XlaOp ConstantR0(NativeT value); template - XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); + XlaOp ConstantR1(absl::Span values); XlaOp ConstantR1(const tensorflow::core::Bitmap& values); template XlaOp ConstantR2( @@ -325,7 +337,7 @@ class XlaBuilder { // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); // Performs in-dimension-style broadcast. // @@ -344,9 +356,8 @@ class XlaBuilder { // will generate output // [1 , 1] // [2 , 2] - XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); + XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config @@ -359,15 +370,13 @@ class XlaBuilder { // given, followed by reshaping it into the shape with the given dimension // sizes (also major to minor). Conceptually, this is a limited form of // "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); // Enqueues an operation onto the computation that collapses the operand, from // first to last dimension (C order), then reshapes it to the given dimension // sizes. Conceptually, this is a limited form of "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); + XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); // Wrapper for Reshape. // Enqueues an operation to collapse the provided dimensions; e.g. an @@ -387,8 +396,7 @@ class XlaBuilder { // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. - XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); // Enqueues a slice operation onto the computation that slices the operand // from the start indices to the limit indices; e.g. @@ -401,10 +409,9 @@ class XlaBuilder { // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D // range notation. // The strides parameter determines the stride over the slice - XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to @@ -425,7 +432,7 @@ class XlaBuilder { // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. @@ -448,8 +455,7 @@ class XlaBuilder { // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. - XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, - int64 dimension); + XlaOp ConcatInDim(absl::Span operands, int64 dimension); // Enqueue a tracing operation onto the computation; the computation will emit // a logging message with the operand. @@ -460,84 +466,93 @@ class XlaBuilder { XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); // Enqueues a tuple-creation instruction onto the computation. - XlaOp Tuple(tensorflow::gtl::ArraySlice elements); + XlaOp Tuple(absl::Span elements); // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. - XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_strides, Padding padding, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. - XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. - XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. @@ -559,25 +574,14 @@ class XlaBuilder { // Enqueues a call instruction onto the computation. XlaOp Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Enqueues a custom call instruction onto the computation. // During code generation, a call instruction is emitted which targets a // symbol with the name |call_target_name|. The |operands| are passed to the // call instruction. |shape| is the resultant shape. XlaOp CustomCall(const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - - // Enqueues a pseudo-op to represent host-side computation data-dependencies. - // During code generation, host send and receive operations will be generated - // to transfer |operands| to the host and a single result of |shape| back to - // the device. Host send/recv operations are emitted using |channel_name|. - // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO - // instruction scheduling. - XlaOp HostCompute(tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, - const Shape& shape); + absl::Span operands, const Shape& shape); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -586,65 +590,70 @@ class XlaBuilder { // Enqueues a complex compose instruction onto the computation. XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a min instruction onto the computation. XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Not(const XlaOp& operand); XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); + XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); + + // Reduces several arrays simultaneously among the provided dimensions, given + // "computation" as a reduction operator. + XlaOp Reduce(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. @@ -654,25 +663,23 @@ class XlaBuilder { // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All // replicas supply one input to the sum and all replicas receive the resulting // sum for each subgroup. - XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids = {}); + XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -681,27 +688,38 @@ class XlaBuilder { // scalars, e.g., add, min, or max. The way that AllReduce is applied is // configured by: // - // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all - // replicas belong to one group. Allreduce will be applied within subgroups. - // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, - // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // - `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group. Allreduce will be applied within + // subgroups. For example, we have 4 replicas, then + // replica_groups={{0,2},{1,3}} means, replica 0 and 2 are in subgroup 0, + // replica 1 and 3 are in subgroup 1. // - // - `channel_id`: for Allreduce nodes from different models, if they have the - // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be - // applied cross models. + // - `channel_id`: for Allreduce nodes from different modules, if they have + // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids = {}, - const tensorflow::gtl::optional& channel_id = - tensorflow::gtl::nullopt); + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt); + + // Enqueues an operation that do an Alltoall of the operand cross cores. + XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); + + // Enqueues an operation that do an CollectivePermute of the operand cross + // cores. + XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); @@ -710,18 +728,17 @@ class XlaBuilder { // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); @@ -768,7 +785,7 @@ class XlaBuilder { // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an operator that tests if the operand's values are finite, i.e., // not Inf or NaN. Defined only for floating-point types. Returns an array of @@ -776,6 +793,12 @@ class XlaBuilder { // entry was NaN. XlaOp IsFinite(const XlaOp& operand); + // Enqueues an iota operation onto the computation. + XlaOp Iota(const Shape& shape, int64 iota_dimension); + + // Enqueues a rank-1 iota operation onto the computation. + XlaOp Iota(PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, @@ -792,14 +815,12 @@ class XlaBuilder { XlaOp Neg(const XlaOp& operand); // Enqueues a transpose instruction onto the computation. - XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); + XlaOp Transpose(const XlaOp& operand, absl::Span permutation); // Enqueues a reverse instruction onto the computation. The order of the // elements in the given dimensions is reversed (i.e., the element at index i // is moved to index dimension_size - 1 - i). - XlaOp Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. // If only keys are provided: @@ -817,18 +838,16 @@ class XlaBuilder { // * The result is a tuple that consists of a sorted tensor of keys (along the // provided dimension, as above) as the first element, and a tensor with their // corresponding values as the second element. - XlaOp Sort(XlaOp keys, - tensorflow::gtl::optional values = tensorflow::gtl::nullopt, + XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); // Enqueues a map instruction onto the computation. - XlaOp Map(tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); + XlaOp Map(absl::Span operands, const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the // computation. @@ -853,9 +872,9 @@ class XlaBuilder { const int mantissa_bits); // Enqueues a Gather node onto the computation. - XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -883,7 +902,7 @@ class XlaBuilder { // Enqueues an AfterAll operation with no operands producing a token-shaped // value. - XlaOp AfterAll(tensorflow::gtl::ArraySlice tokens); + XlaOp AfterAll(absl::Span tokens); // Enqueues a Recv node onto the computation. The data comes from a Send // instruction that shares the same channel handle and its shape must @@ -930,14 +949,15 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); - StatusOr AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands = {}); + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands = {}); void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); StatusOr LookUpInstruction(const XlaOp& op) const; + StatusOr LookUpInstructionByHandle( + int64 handle) const; // Internal helper method that does the building for an arbitrary unary op. XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); @@ -946,19 +966,17 @@ class XlaBuilder { // broadcast_dimensions specifies which dimensions to use for broadcasting // when the operation is between tensors of different ranks. XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Internal helper method that does the building for an arbitrary ternary op. XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs); XlaOp RngOp(RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); + absl::Span parameters, const Shape& shape); - StatusOr InDimBroadcast( - const Shape& shape, const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_dimensions); + StatusOr InDimBroadcast(const Shape& shape, const XlaOp& operand, + absl::Span broadcast_dimensions); // Internal helper method that creates a sequence of instructions that // performs an explicit broadcast of the operand to the target shape. @@ -969,13 +987,12 @@ class XlaBuilder { // shape. StatusOr Reshape(const Shape& shape, const XlaOp& operand); - // Returns the (inferred) result for the program shape for the current - // computation and fills the root_id in the pointer. - StatusOr GetProgramShape(int64* root_id) const; + // Returns the (inferred) result for the program shape using the given root. + StatusOr GetProgramShape(int64 root_id) const; // Returns shapes for the operands. StatusOr> GetOperandShapes( - tensorflow::gtl::ArraySlice operands) const; + absl::Span operands) const; // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful @@ -992,12 +1009,11 @@ class XlaBuilder { // Helper function for creating a Window proto from user-supplied data. // Returns error if the user-supplied data was invalid. - StatusOr MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation) const; + StatusOr MakeWindow(absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation) const; string name_; // Name to use for the built computation. @@ -1011,6 +1027,10 @@ class XlaBuilder { // The instructions of this computation. std::vector instructions_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the + // instruction is held. + tensorflow::gtl::FlatMap handle_to_index_; + // The embedded computations used by this computation. Each computation was // the entry computation of some XlaComputation, the key is the unique id of // that XlaComputation. @@ -1026,7 +1046,7 @@ class XlaBuilder { // Sharding for this operator. This is structured as a "model"-like operation, // in order to simplify client code, similar to metadata_. - tensorflow::gtl::optional sharding_; + absl::optional sharding_; // Mode bit that indicates whether to die when a first error is encountered. bool die_immediately_on_error_ = false; @@ -1041,7 +1061,7 @@ class XlaBuilder { friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); template friend XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values); + absl::Span values); friend XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); template @@ -1081,175 +1101,180 @@ class XlaBuilder { friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); friend XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); friend XlaOp BroadcastInDim( const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); + const absl::Span broadcast_dimensions); friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config); - friend XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + friend XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); - friend XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); + friend XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); friend XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); friend XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); friend XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - int64 dimension); + absl::Span operands, int64 dimension); friend void Trace(const string& tag, const XlaOp& operand); friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); - friend XlaOp Tuple(XlaBuilder* builder, - tensorflow::gtl::ArraySlice elements); + friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + absl::Span broadcast_dimensions); + friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfig* precision_config); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_strides, Padding padding, + int64 feature_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); - friend XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, const PrecisionConfig* precision_config); + friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config); friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - friend XlaOp HostCompute(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, - const Shape& shape); + absl::Span operands, const Shape& shape); friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Conj(const XlaOp& operand); friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Not(const XlaOp& operand); - friend XlaOp ShiftLeft( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions); friend XlaOp ShiftRightArithmetic( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - friend XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); + friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions); friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); + friend XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation); - friend XlaOp ReduceWindow( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding); friend XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - friend XlaOp CrossReplicaSum( + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding); + friend XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups); + friend XlaOp CrossReplicaSum(const XlaOp& operand, + const XlaComputation& computation, + absl::Span replica_groups, + const absl::optional& channel_id); + friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); + friend XlaOp CollectivePermute( const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids); - friend XlaOp CrossReplicaSum( - const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids, - const tensorflow::gtl::optional& channel_id); - friend XlaOp SelectAndScatter( - const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + const std::vector>& source_target_pairs); + friend XlaOp SelectAndScatter(const XlaOp& operand, + const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); friend XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); friend XlaOp Abs(const XlaOp& operand); friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Exp(const XlaOp& operand); friend XlaOp Expm1(const XlaOp& operand); friend XlaOp Floor(const XlaOp& operand); @@ -1265,28 +1290,25 @@ class XlaBuilder { friend XlaOp Real(const XlaOp& operand); friend XlaOp Imag(const XlaOp& operand); friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp IsFinite(const XlaOp& operand); - // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota - // in xla/client/lib/numeric.h with this (renamed to xla::Iota). - friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); + friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, + int64 iota_dimension); + friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); friend XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); friend XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); friend XlaOp Neg(const XlaOp& operand); friend XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); - friend XlaOp Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); - friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values, - int64 dimension); + absl::Span permutation); + friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); + friend XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); - friend XlaOp Map(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, + friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + absl::Span dimensions, + absl::Span static_operands); friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); @@ -1298,9 +1320,9 @@ class XlaBuilder { const XlaComputation& false_computation); friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); - friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, @@ -1334,8 +1356,7 @@ class XlaBuilder { const Shape& shape_with_layout, const string& outfeed_config); friend XlaOp CreateToken(XlaBuilder* builder); - friend XlaOp AfterAll(XlaBuilder* builder, - tensorflow::gtl::ArraySlice tokens); + friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1343,7 +1364,7 @@ class XlaBuilder { class XlaScopedShardingAssignment { public: XlaScopedShardingAssignment(xla::XlaBuilder* builder, - tensorflow::gtl::optional sharding) + absl::optional sharding) : builder_(builder), prev_sharding_(builder->sharding()) { SetSharding(sharding); } @@ -1355,7 +1376,7 @@ class XlaScopedShardingAssignment { ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } private: - void SetSharding(const tensorflow::gtl::optional& sharding) { + void SetSharding(const absl::optional& sharding) { if (sharding.has_value()) { builder_->SetSharding(sharding.value()); } else { @@ -1364,7 +1385,7 @@ class XlaScopedShardingAssignment { } xla::XlaBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; + absl::optional prev_sharding_; }; // Free functions for building XlaOps. The intention is that these will @@ -1399,8 +1420,7 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); template XlaOp ConstantR0(XlaBuilder* builder, NativeT value); template -XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values); +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values); XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); template XlaOp ConstantR2(XlaBuilder* builder, @@ -1449,8 +1469,7 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); // The new dimensions index into copies of the operand, i.e. // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] -XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); +XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); // Performs in-dimension-style broadcast. // @@ -1469,9 +1488,8 @@ XlaOp Broadcast(const XlaOp& operand, // will generate output // [1 , 1] // [2 , 2] -XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); +XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config @@ -1484,15 +1502,13 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, // given, followed by reshaping it into the shape with the given dimension // sizes (also major to minor). Conceptually, this is a limited form of // "shape casting". -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); +XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); // Enqueues an operation onto the computation that collapses the operand, from // first to last dimension (C order), then reshapes it to the given dimension // sizes. Conceptually, this is a limited form of "shape casting". -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); +XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); // Wrapper for Reshape. // Enqueues an operation to collapse the provided dimensions; e.g. an @@ -1512,8 +1528,7 @@ XlaOp Reshape(const XlaOp& operand, // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. -XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); +XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); // Enqueues a slice operation onto the computation that slices the operand // from the start indices to the limit indices; e.g. @@ -1526,10 +1541,9 @@ XlaOp Collapse(const XlaOp& operand, // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D // range notation. // The strides parameter determines the stride over the slice -XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); +XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to @@ -1550,7 +1564,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. @@ -1573,8 +1587,8 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. -XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, int64 dimension); +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, + int64 dimension); // Enqueue a tracing operation onto the computation; the computation will emit // a logging message with the operand. @@ -1585,82 +1599,91 @@ void Trace(const string& tag, const XlaOp& operand); XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); // Enqueues a tuple-creation instruction onto the computation. -XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements); +XlaOp Tuple(XlaBuilder* builder, absl::Span elements); // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. -XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + absl::Span window_strides, Padding padding, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); +XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. @@ -1692,26 +1715,14 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, // Enqueues a call instruction onto the computation. XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Enqueues a custom call instruction onto the computation. // During code generation, a call instruction is emitted which targets a // symbol with the name |call_target_name|. The |operands| are passed to the // call instruction. |shape| is the resultant shape. XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - -// Enqueues a pseudo-op to represent host-side computation data-dependencies. -// During code generation, host send and receive operations will be generated -// to transfer |operands| to the host and a single result of |shape| back to -// the device. Host send/recv operations are emitted using |channel_name|. -// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO -// instruction scheduling. -XlaOp HostCompute(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, - const Shape& shape); + absl::Span operands, const Shape& shape); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -1720,65 +1731,70 @@ XlaOp HostCompute(XlaBuilder* builder, // Enqueues a complex compose instruction onto the computation. XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a min instruction onto the computation. XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Not(const XlaOp& operand); XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); -XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); -XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); + +// Reduces several arrays simultaneously among the provided dimensions, given +// "computation" as a reduction operator. +XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. @@ -1788,25 +1804,23 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All // replicas supply one input to the sum and all replicas receive the resulting // sum for each subgroup. -XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids = {}); +XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -1815,45 +1829,61 @@ XlaOp CrossReplicaSum( // scalars, e.g., add, min, or max. The way that AllReduce is applied is // configured by: // -// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all -// replicas belong to one group. Allreduce will be applied within subgroups. -// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, -// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. +// - `replica_groups`: each ReplicaGroup contains a list of replica id. If +// empty, all replicas belong to one group. Allreduce will be applied within +// subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} +// means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // -// - `channel_id`: for Allreduce nodes from different models, if they have the +// - `channel_id`: for Allreduce nodes from different modules, if they have the // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be -// applied cross models. +// applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. -XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids = {}, - const tensorflow::gtl::optional& - channel_id = tensorflow::gtl::nullopt); +XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt); + +// Enqueues an operation that do an Alltoall of the operand cross cores. +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups = {}); + +// Enqueues an collective operation that sends and receives data cross replicas. +// +// - `source_target_pair`: a list of (source_replica_id, target_replica_id) +// pairs. For each pair, the operand is sent from source replica to target +// replica. Note that, 1) any two pairs should not have the same target replica +// id, and they should not have the same source replica id; 2) if a replica id +// is not a target in any pair, then the output on that replica is a tensor +// consists of 0(s) with the same shape as the input. +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); // As SelectAndScatter(), but the padding is given in the format // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); @@ -1900,7 +1930,7 @@ XlaOp Imag(const XlaOp& operand); // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an operator that tests if the operand's values are finite, i.e., // not Inf or NaN. Defined only for floating-point types. Returns an array of @@ -1908,6 +1938,12 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, // entry was NaN. XlaOp IsFinite(const XlaOp& operand); +// Enqueues an iota operation onto the computation. +XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); + +// Enqueues a rank-1 iota operation onto the computation. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); @@ -1922,13 +1958,12 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); XlaOp Neg(const XlaOp& operand); // Enqueues a transpose instruction onto the computation. -XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); +XlaOp Transpose(const XlaOp& operand, absl::Span permutation); // Enqueues a reverse instruction onto the computation. The order of the // elements in the given dimensions is reversed (i.e., the element at index i // is moved to index dimension_size - 1 - i). -XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); +XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. // If only keys are provided: @@ -1946,18 +1981,16 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); // * The result is a tuple that consists of a sorted tensor of keys (along the // provided dimension, as above) as the first element, and a tensor with their // corresponding values as the second element. -XlaOp Sort(XlaOp keys, - tensorflow::gtl::optional values = tensorflow::gtl::nullopt, +XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); // Enqueues a map instruction onto the computation. -XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, absl::Span dimensions, + absl::Span static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the // computation. @@ -1982,9 +2015,9 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); // Enqueues a Gather node onto the computation. -XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -2042,7 +2075,7 @@ XlaOp CreateToken(XlaBuilder* builder); // Enqueues an AfterAll instruction which produces a token-shaped value and // takes a variadic number of token-shaped operands. The number of operands must // be greater than zero. Used for joining tokens. -XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens); +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); // Normalizes operand across spatial and batch dimensions for each feature. // @@ -2086,12 +2119,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, template XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*LiteralUtil::CreateR0(value)); + return ConstantLiteral(LiteralUtil::CreateR0(value)); } template -XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); +XlaOp XlaBuilder::ConstantR1(absl::Span values) { + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template @@ -2103,44 +2136,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { } inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template XlaOp XlaBuilder::ConstantR2( std::initializer_list> values) { - return ConstantLiteral(*LiteralUtil::CreateR2(values)); + return ConstantLiteral(LiteralUtil::CreateR2(values)); } template XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(*LiteralUtil::CreateFromArray(values)); + return ConstantLiteral(LiteralUtil::CreateFromArray(values)); } template XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D(values)); + return ConstantLiteral(LiteralUtil::CreateR2FromArray2D(values)); } template XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template @@ -2163,13 +2196,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { template XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, *LiteralUtil::CreateR0(value)); + return ConstantLiteral(builder, LiteralUtil::CreateR0(value)); } template -XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template @@ -2182,13 +2214,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { inline XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template XlaOp ConstantR2(XlaBuilder* builder, std::initializer_list> values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR2(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR2(values)); } template @@ -2196,14 +2228,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, const Array& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateFromArray(values)); + LiteralUtil::CreateFromArray(values)); } template @@ -2211,15 +2242,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, const Array2D& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantR2FromArray2D(XlaBuilder* builder, const Array2D& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateR2FromArray2D(values)); + LiteralUtil::CreateR2FromArray2D(values)); } template @@ -2228,7 +2258,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, const Layout& layout) { return ConstantLiteral( builder, - *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 28a207b137d901213ec43d506a638ef08a6bded9..7c37ed00cd3dcc214fb0b36c0161d3c39a5bf8c8 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -46,6 +47,17 @@ class XlaBuilderTest : public ::testing::Test { return HloModule::CreateFromProto(proto, config); } + // Overload which explicitly specifies the root instruction. + StatusOr> BuildHloModule(XlaBuilder* b, + XlaOp root) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root)); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, legacy_flags::GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); + } + // Returns the name of the test currently being run. string TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); @@ -293,6 +305,30 @@ TEST_F(XlaBuilderTest, Transpose) { EXPECT_THAT(root, op::Transpose(op::Parameter())); } +TEST_F(XlaBuilderTest, AllToAll) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); + AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, + /*split_count=*/2); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + // AllToAll is decomposed into slices -> all-to-all -> gte -> concat. + EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); + EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll); + EXPECT_TRUE( + ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); +} + +TEST_F(XlaBuilderTest, CollectivePermute) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); @@ -320,5 +356,45 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); } +TEST_F(XlaBuilderTest, BuildWithSpecificRoot) { + XlaBuilder b(TestName()); + XlaOp constant = ConstantR0(&b, 1.0); + Add(constant, ConstantR0(&b, 2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Constant()); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { + // Specifying a particular root in Build should still include all entry + // parameters. + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + XlaOp x = Parameter(&b, 0, shape, "x"); + XlaOp y = Parameter(&b, 1, shape, "y"); + XlaOp z = Parameter(&b, 2, shape, "z"); + Add(x, Sub(y, z)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter()); + EXPECT_EQ(module->entry_computation()->num_parameters(), 3); + EXPECT_EQ(module->entry_computation()->instruction_count(), 5); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { + XlaBuilder b(TestName()); + XlaBuilder other_b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + + Parameter(&b, 0, shape, "param"); + XlaOp other_param = Parameter(&other_b, 0, shape, "other_param"); + + Status status = b.Build(other_param).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("root operation is not in this computation")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD deleted file mode 100644 index 2e131dbad26970d4cb9860c17c3de3d52de36223..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -# Description: -# The new XLA client libraries. - -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "xla_builder", - hdrs = ["xla_builder.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/xla/client:xla_builder", - ], -) diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index 3543d41fc2656ec028646edebc0bf5b6af7f67a5..22c9e83bb2ae9e3e205bdd480b64c703e31c6ffd 100644 --- a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -32,7 +32,7 @@ StatusOr> XlaComputation::Snapshot() const { if (IsNull()) { return InvalidArgument("Computation is invalid."); } - auto session = MakeUnique(); + auto session = absl::make_unique(); *session->mutable_hlo()->mutable_hlo_module() = proto_; return std::move(session); } diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h index 1a51fdee680721a4a03fa5de79a81746d92af76b..6d51126d882f87a84b054e9db599b995868824bf 100644 --- a/tensorflow/compiler/xla/device_util.h +++ b/tensorflow/compiler/xla/device_util.h @@ -21,8 +21,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -30,8 +30,8 @@ namespace xla { // Returns a string that represents the device in terms of platform and ordinal; // e.g. the first CUDA device will be "cuda:0" string DeviceIdentifier(se::StreamExecutor* stream_exec) { - return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":", - stream_exec->device_ordinal()); + return absl::StrCat(stream_exec->platform()->Name(), ":", + stream_exec->device_ordinal()); } } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index a472747bd174e3bbd352f07f2ab092e678b81073..0f9b591c70d4fd96147958d18bd5fb7dd78a7f3f 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -45,6 +45,16 @@ stream_executor::Stream* ExecutableRunOptions::stream() const { return stream_; } +ExecutableRunOptions& ExecutableRunOptions::set_host_to_device_stream( + stream_executor::Stream* stream) { + host_to_device_stream_ = stream; + return *this; +} + +stream_executor::Stream* ExecutableRunOptions::host_to_device_stream() const { + return host_to_device_stream_; +} + ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool) { intra_op_thread_pool_ = intra_op_thread_pool; diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 416131be006e6ecddb47651f8b684c1d91df4892..ba3217f31b55bd1428f67da6154a46c8bc304053 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -65,6 +65,13 @@ class ExecutableRunOptions { ExecutableRunOptions& set_stream(stream_executor::Stream* stream); stream_executor::Stream* stream() const; + // If set, this is the stream to perform any pre-computation transfers on. + // The platform of the stream must match the platform the executable was + // built for. A value of nullptr indicates the option has not been set. + ExecutableRunOptions& set_host_to_device_stream( + stream_executor::Stream* stream); + stream_executor::Stream* host_to_device_stream() const; + // Sets the thread pool device on which to run Eigen subcomputations. // Does not take ownership. ExecutableRunOptions& set_intra_op_thread_pool( @@ -90,6 +97,7 @@ class ExecutableRunOptions { const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; + stream_executor::Stream* host_to_device_stream_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index ffd1fb79e986f82e1c2721f0eefbf3b4c0838e41..3fadabcf5207097aa875d654320b930b1ed94ad3 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -18,16 +18,16 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" namespace xla { /* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex( - const Shape& shape, tensorflow::gtl::ArraySlice multi_index) { + const Shape& shape, absl::Span multi_index) { DCHECK_EQ(shape.dimensions_size(), multi_index.size()); // Padding and nested layouts not supported yet. DCHECK_EQ(0, shape.layout().padded_dimensions_size()); @@ -36,7 +36,7 @@ namespace xla { DCHECK_GE(multi_index[i], 0); DCHECK_LT(multi_index[i], shape.dimensions(i)) << "indexing beyond extent in dimension " << i << ":" - << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",") + << "\n\tindex: " << absl::StrJoin(multi_index, ",") << "\n\tshape: " << ShapeUtil::HumanString(shape); } @@ -118,8 +118,8 @@ namespace xla { return multi_index; } -/* static */ bool IndexUtil::BumpIndices( - const Shape& shape, tensorflow::gtl::MutableArraySlice indices) { +/* static */ bool IndexUtil::BumpIndices(const Shape& shape, + absl::Span indices) { for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) { int64 limit = shape.dimensions(dimno); if (indices[dimno] + 1 < limit) { @@ -149,8 +149,8 @@ namespace xla { return stride; } -/* static */ bool IndexUtil::IndexInBounds( - const Shape& shape, tensorflow::gtl::ArraySlice index) { +/* static */ bool IndexUtil::IndexInBounds(const Shape& shape, + absl::Span index) { int64 rank = ShapeUtil::Rank(shape); if (rank != index.size()) { return false; @@ -163,9 +163,8 @@ namespace xla { return true; } -/* static */ int IndexUtil::CompareIndices( - tensorflow::gtl::ArraySlice lhs, - tensorflow::gtl::ArraySlice rhs) { +/* static */ int IndexUtil::CompareIndices(absl::Span lhs, + absl::Span rhs) { int64 rank = lhs.size(); CHECK_EQ(rhs.size(), rank); for (int64 dim = 0; dim < rank; ++dim) { diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 142006f2626e83d3254f2de65fc28fd5d6694e53..2979cf87dde92893ce2151cb09b46c8db8473b31 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -20,9 +20,9 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -35,7 +35,7 @@ class IndexUtil { // on the shape and its layout. The first index in the multi_index is // dimension 0. static int64 MultidimensionalIndexToLinearIndex( - const Shape& shape, tensorflow::gtl::ArraySlice multi_index); + const Shape& shape, absl::Span multi_index); // Converts a linear index into multidimensional index (eg {x, y, z}) based on // the shape and its layout. The first index in the returned multidimensional @@ -58,8 +58,7 @@ class IndexUtil { // // Returns true iff the indices were successfully bumped; false if we've hit // the limit where it can no longer be bumped in-bounds. - static bool BumpIndices(const Shape& shape, - tensorflow::gtl::MutableArraySlice indices); + static bool BumpIndices(const Shape& shape, absl::Span indices); // Calculates the stride size (in number of elements, not byte size) of a // given logical shape dimension (from 0 to rank-1). If available, padded @@ -71,15 +70,14 @@ class IndexUtil { // Returns true iff the given multi-index is contained in the bounds for the // shape. - static bool IndexInBounds(const Shape& shape, - tensorflow::gtl::ArraySlice index); + static bool IndexInBounds(const Shape& shape, absl::Span index); // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger, // then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is // returned. - static int CompareIndices(tensorflow::gtl::ArraySlice lhs, - tensorflow::gtl::ArraySlice rhs); + static int CompareIndices(absl::Span lhs, + absl::Span rhs); private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc index 7c4efdee484d9530a69b31cbe3a0d69a8a3cffa7..93522d2ca87a7eba8d3c7533785c54e63ce507b0 100644 --- a/tensorflow/compiler/xla/index_util_test.cc +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -142,13 +142,13 @@ TEST(IndexUtilTest, LinearToMultiToLinear) { TEST(IndexUtilTest, BumpIndices2x2) { auto shape = ShapeUtil::MakeShape(S32, {2, 2}); std::vector indices = {0, 0}; - EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); EXPECT_THAT(indices, ::testing::ElementsAre(0, 1)); - EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); EXPECT_THAT(indices, ::testing::ElementsAre(1, 0)); - EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); EXPECT_THAT(indices, ::testing::ElementsAre(1, 1)); - EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_FALSE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); } } // namespace diff --git a/tensorflow/compiler/xla/iterator_util.h b/tensorflow/compiler/xla/iterator_util.h index a8bb8c7a7e6784e555f4e9dad73ecc78c668ac42..3a3ee21e7635b9dee61f59e4e8c69eec3d420c86 100644 --- a/tensorflow/compiler/xla/iterator_util.h +++ b/tensorflow/compiler/xla/iterator_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_ #include #include @@ -95,4 +95,4 @@ UnwrappingIterator MakeUnwrappingIterator(NestedIter iter) { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/iterator_util_test.cc b/tensorflow/compiler/xla/iterator_util_test.cc index 7bc3189507ec5233c6983eb26cfb07dc9bfadd52..ec8b66df2db0b9d8c045fbf6133f607e57c81c26 100644 --- a/tensorflow/compiler/xla/iterator_util_test.cc +++ b/tensorflow/compiler/xla/iterator_util_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/test.h" namespace xla { @@ -27,7 +27,7 @@ namespace { TEST(UnwrappingIteratorTest, Simple) { std::vector> v; for (int i = 0; i < 3; ++i) { - v.push_back(MakeUnique(i)); + v.push_back(absl::make_unique(i)); } int i = 0; for (auto iter = MakeUnwrappingIterator(v.begin()); @@ -51,7 +51,7 @@ TEST(UnwrappingIteratorTest, PostincrementOperator) { TEST(UnwrappingIteratorTest, StdFind) { std::list> l; for (int i = 0; i < 3; ++i) { - l.push_back(MakeUnique(i)); + l.push_back(absl::make_unique(i)); } EXPECT_EQ(l.begin()->get(), *std::find(MakeUnwrappingIterator(l.begin()), diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index b72d190d54591384392e79e73e90cf52df04a902..d310335618ded7b581e6ed632223218585bb791f 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,8 +33,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -56,7 +56,7 @@ void SetDefaultLayoutToContainer( } // namespace /* static */ Layout LayoutUtil::MakeLayout( - tensorflow::gtl::ArraySlice minor_to_major) { + absl::Span minor_to_major) { Layout layout; layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { @@ -66,7 +66,7 @@ void SetDefaultLayoutToContainer( } /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( - tensorflow::gtl::ArraySlice major_to_minor) { + absl::Span major_to_minor) { Layout layout; layout.set_format(DENSE); for (int i = major_to_minor.size() - 1; i >= 0; i--) { @@ -169,7 +169,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } return ValidateLayoutForShape(shape.layout(), shape); } else { @@ -177,7 +177,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (shape.has_layout()) { return InvalidArgument( "shape of primitive type %s should not have a layout", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return Status::OK(); } @@ -194,7 +194,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { layout.padded_dimensions_size() != 0) { return InvalidArgument( "shape of primitive type %s should not have a non-trivial layout", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return Status::OK(); } @@ -202,17 +202,17 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (layout.format() == INVALID_FORMAT) { return InvalidArgument( "Layout does not have a valid format: layout {%s}, shape {%s}", - layout.ShortDebugString().c_str(), shape.ShortDebugString().c_str()); + layout.ShortDebugString(), shape.ShortDebugString()); } if (layout.format() == DENSE) { if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( "layout minor_to_major field contains %d elements, " - "but shape is rank %lld: {%s}; shape: %s", + "but shape is rank %d: {%s}; shape: %s", layout.minor_to_major_size(), ShapeUtil::Rank(shape), - tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), - shape.ShortDebugString().c_str()); + absl::StrJoin(layout.minor_to_major(), ", "), + shape.ShortDebugString()); } std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); @@ -221,12 +221,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { return InvalidArgument( "layout minor_to_major field has out-of-bounds value: %s", - HumanString(layout).c_str()); + HumanString(layout)); } if (dimensions_in_layout[dim]) { return InvalidArgument( "layout minor_to_major field has duplicate values: {%s}", - HumanString(layout).c_str()); + HumanString(layout)); } dimensions_in_layout[dim] = true; } @@ -234,14 +234,14 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (layout.padded_dimensions_size() > 0) { if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "layout has %d padded dimensions, but shape is rank %lld", + "layout has %d padded dimensions, but shape is rank %d", layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); } for (int i = 0; i < layout.padded_dimensions_size(); ++i) { if (layout.padded_dimensions(i) < shape.dimensions(i)) { return InvalidArgument( - "for dimension %d, dimension padding (%lld) is smaller than " - "the dimension size (%lld) of the shape", + "for dimension %d, dimension padding (%d) is smaller than " + "the dimension size (%d) of the shape", i, layout.padded_dimensions(i), shape.dimensions(i)); } } @@ -307,7 +307,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return false; } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::PaddedDimensions( +/* static */ absl::Span LayoutUtil::PaddedDimensions( const Shape& shape) { CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().padded_dimensions()); @@ -363,13 +363,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return protobuf_util::ProtobufEquals(lhs, rhs); } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( +/* static */ absl::Span LayoutUtil::MinorToMajor( const Shape& shape) { CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().minor_to_major()); } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( +/* static */ absl::Span LayoutUtil::MinorToMajor( const Layout& layout) { CHECK(layout.format() == DENSE); return AsInt64Slice(layout.minor_to_major()); @@ -403,12 +403,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ string LayoutUtil::HumanString(const Layout& layout) { if (IsSparse(layout)) { - return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(), - "}"); + return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); } CHECK(IsDense(layout)); - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}"); + return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); } namespace { @@ -474,7 +472,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } /* static */ bool LayoutUtil::AreDimensionsConsecutive( - const Layout& layout, tensorflow::gtl::ArraySlice dims) { + const Layout& layout, absl::Span dims) { CHECK(IsDense(layout)); std::vector positions_in_layout; for (int64 dim : dims) { diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 739bbe73675c7fb855627006028eafdf703d6540..b78883c2d870043032306637730c4666665125a8 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -20,10 +20,10 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -34,11 +34,11 @@ class LayoutUtil { public: // Creates a layout with the given minor-to-major dimension order. (This is a // convenience function for protobuf construction.) - static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + static Layout MakeLayout(absl::Span minor_to_major); // Similar to MakeLayout, but take indices in reverse order. static Layout MakeLayoutFromMajorToMinor( - tensorflow::gtl::ArraySlice major_to_minor); + absl::Span major_to_minor); // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) @@ -104,8 +104,7 @@ class LayoutUtil { // Returns the padded_dimensions array for the given Shape. Requires that the // shape is an array and has a dense layout. - static tensorflow::gtl::ArraySlice PaddedDimensions( - const Shape& shape); + static absl::Span PaddedDimensions(const Shape& shape); // Returns the given index of the padded_dimensions array for the given Shape. // Requires that the shape is an array and has a dense layout. @@ -138,8 +137,8 @@ class LayoutUtil { // Returns the minor_to_major array for the given Shape. Requires that the // shape is an array and has a dense layout. - static tensorflow::gtl::ArraySlice MinorToMajor(const Shape& shape); - static tensorflow::gtl::ArraySlice MinorToMajor(const Layout& layout); + static absl::Span MinorToMajor(const Shape& shape); + static absl::Span MinorToMajor(const Layout& layout); // Major(0) is the most major logical dimension number, Major(1) is the // second-most-major logical dimension number and so on. @@ -196,7 +195,7 @@ class LayoutUtil { // Returns whether the given dimensions are consecutive in the given layout, // not necessarily in the order given. static bool AreDimensionsConsecutive(const Layout& layout, - tensorflow::gtl::ArraySlice dims); + absl::Span dims); // Compute a hash for `layout`. static size_t Hash(const Layout& layout); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index e4c825450dcd45a8fbeaacbb2ad145f94307176f..f25dae6ff411133c74502039f441060f1329ffd4 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -27,15 +27,15 @@ namespace { class LayoutUtilTest : public ::testing::Test { protected: Shape MakeShapeWithLayout(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + absl::Span dimensions, + absl::Span minor_to_major) { Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); return shape; } Shape MakeShapeWithSparseLayout(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, + absl::Span dimensions, int64 max_sparse_elements) { Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 89353448e29ec3d97275dac288e23aa8e96e31b2..3e79129aafd234e5eab05d205f2017b54057795e 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -39,6 +40,7 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", ], ) @@ -56,6 +58,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -73,5 +76,7 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index f42fb92359f40ec763866af094972046f6407ae1..3ed3afcfcede20fbf5c7d4f004378817febeb4c7 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -17,9 +17,9 @@ limitations under the License. #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" #include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace legacy_flags { @@ -31,7 +31,6 @@ std::vector* flag_objects; std::once_flag flags_init; void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_enable_fast_math(true); flags->set_xla_llvm_enable_alias_scope_metadata(true); flags->set_xla_llvm_enable_noalias_metadata(true); flags->set_xla_llvm_enable_invariant_load_metadata(true); @@ -53,6 +52,13 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // the heuristics needed to decide when to run on multiple streams. See // b/77879207. flags->set_xla_gpu_disable_multi_streaming(true); + + // TODO(jlebar): Disable fastmath once doing so is not a performance + // regression. + flags->set_xla_cpu_enable_fast_math(true); + flags->set_xla_gpu_enable_fast_math(true); + + flags->set_xla_force_host_platform_device_count(1); } // Allocates flag_values and flag_objects; this function must not be called more @@ -83,7 +89,7 @@ void AllocateFlags() { // Custom "sub-parser" lambda for xla_disable_hlo_passes. auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { std::vector disabled_passes = - tensorflow::str_util::Split(comma_separated_values, ','); + absl::StrSplit(comma_separated_values, ','); for (const auto& passname : disabled_passes) { flag_values->add_xla_disable_hlo_passes(passname); } @@ -150,10 +156,16 @@ void AllocateFlags() { flag_values->mutable_xla_generate_hlo_text_to(), "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( - "xla_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_enable_fast_math), - flag_values->xla_enable_fast_math(), - "Enable unsafe fast-math optimizations in the compiler; " + "xla_cpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the CPU compiler; " + "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_gpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the GPU compiler; " "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", @@ -306,6 +318,24 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn), flag_values->xla_cpu_use_mkl_dnn(), "Generate calls to MKL-DNN in the CPU backend."), + tensorflow::Flag( + "xla_gpu_crash_on_verification_failures", + bool_setter_for( + &DebugOptions::set_xla_gpu_crash_on_verification_failures), + flag_values->xla_gpu_crash_on_verification_failures(), + "Crashes the program on extra verification failures, e.g. cuDNN " + "cross checking failures"), + tensorflow::Flag( + "xla_force_host_platform_device_count", + int32_setter_for( + &DebugOptions::set_xla_force_host_platform_device_count), + flag_values->xla_force_host_platform_device_count(), + "Force the host platform to pretend that there are these many " + "host \"devices\". All of these host devices are backed by the same" + "threadpool. Setting this to anything other than 1 can increase " + "overhead from context switching but we let the user override this " + "behavior to help run tests on the host that run models in parallel " + "across multiple devices."), }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h index e9cf435d83d8345e974d83f8e5340dafeba8e3b2..ee7eb019c07cf898e48886955b18710146644cac 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h @@ -17,10 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ #include +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace legacy_flags { @@ -30,7 +30,7 @@ template void parse_xla_backend_extra_options(T* extra_options_map, string comma_separated_values) { std::vector extra_options_parts = - tensorflow::str_util::Split(comma_separated_values, ','); + absl::StrSplit(comma_separated_values, ','); // The flag contains a comma-separated list of options; some options // have arguments following "=", some don't. @@ -59,8 +59,7 @@ void parse_xla_backend_extra_options(T* extra_options_map, inline bool parse_xla_reduce_precision_option( HloReducePrecisionOptions* options, string option_string) { // Split off "LOCATION" from remainder of string. - std::vector eq_split = - tensorflow::str_util::Split(option_string, '='); + std::vector eq_split = absl::StrSplit(option_string, '='); if (eq_split.size() != 2) { return false; } @@ -80,26 +79,25 @@ inline bool parse_xla_reduce_precision_option( } // Split off "E,M" from remainder of string. - std::vector colon_split = - tensorflow::str_util::Split(eq_split[1], ':'); + std::vector colon_split = absl::StrSplit(eq_split[1], ':'); if (colon_split.size() != 2) { return false; } // Split E and M, and parse. std::vector bitsizes; - if (!tensorflow::str_util::SplitAndParseAsInts(colon_split[0], ',', - &bitsizes) || - bitsizes.size() != 2) { - return false; + for (const auto& s : absl::StrSplit(colon_split[0], ',')) { + bitsizes.emplace_back(); + if (!absl::SimpleAtoi(s, &bitsizes.back())) { + return false; + } } options->set_exponent_bits(bitsizes[0]); options->set_mantissa_bits(bitsizes[1]); // Split off OPS comma-separated list from remainder of string, if the // remainder exists. - std::vector semicolon_split = - tensorflow::str_util::Split(colon_split[1], ';'); + std::vector semicolon_split = absl::StrSplit(colon_split[1], ';'); if (semicolon_split.size() > 2) { return false; } @@ -113,8 +111,7 @@ inline bool parse_xla_reduce_precision_option( options->add_opcodes_to_suffix(i); } } else { - std::vector opcodes = - tensorflow::str_util::Split(opcode_string, ','); + std::vector opcodes = absl::StrSplit(opcode_string, ','); for (const string& opcode : opcodes) { bool found = false; for (int i = 0; i < HloOpcodeCount(); i++) { @@ -132,8 +129,7 @@ inline bool parse_xla_reduce_precision_option( // Process the NAMES string, if it exists. if (semicolon_split.size() == 2) { - std::vector opnames = - tensorflow::str_util::Split(semicolon_split[1], ','); + std::vector opnames = absl::StrSplit(semicolon_split[1], ','); for (const string& opname : opnames) { if (opname.length() > 0) { options->add_opname_substrings_to_suffix(opname); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc index 0ed788a9676fe9b1bd06fb3ceabf627c108a2c70..6f197aec53c7596e84437a03affa9118f22f5a1d 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc index 7b6ae311c1099dccb8dceb2f49743c1b185cd5ab..138c0c852e2bb0527d171f25b4d96cedc5671516 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/test.h" @@ -106,8 +106,8 @@ TEST(ParseFlagsFromEnv, File) { if (tmp_dir == nullptr) { tmp_dir = kTempDir; } - string tmp_file = tensorflow::strings::Printf("%s/parse_flags_from_env.%d", - tmp_dir, getpid()); + string tmp_file = + absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid()); FILE* fp = fopen(tmp_file.c_str(), "w"); CHECK_NE(fp, nullptr) << "can't write to " << tmp_file; for (int i = 0; kTestFlagString[i] != '\0'; i++) { diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 0545deb096e9eace5a9713f200e10559aa718441..5035f4198890857fcafd0156d7eaeeb4bc164322 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -22,6 +22,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -30,19 +34,15 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::Printf; -using tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; +using absl::StrFormat; + constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; // Converts between little and big endian. @@ -71,9 +71,9 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) { return out; } -Literal::StrideConfig::StrideConfig( +MutableLiteralBase::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions) + absl::Span dimensions) : dimensions(dimensions), base(dimensions.size(), 0), step(dimensions.size(), 1) { @@ -133,7 +133,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { } Literal::Literal(const Shape& shape, bool allocate_arrays) - : LiteralBase(), shape_(MakeUnique(shape)) { + : MutableLiteralBase() { + shape_ = absl::make_unique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); root_piece_->set_subshape(shape_.get()); @@ -159,7 +160,9 @@ void Literal::DeallocateBuffers() { }); } -Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } +Literal::Literal(Literal&& other) : MutableLiteralBase() { + *this = std::move(other); +} Literal& Literal::operator=(Literal&& other) { DCHECK(&other.root_piece_->subshape() == other.shape_.get()); @@ -171,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) { return *this; } -std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique(shape); - literal->root_piece_->ForEachMutableSubpiece( +Literal LiteralBase::CreateFromShape(const Shape& shape) { + Literal literal(shape); + literal.root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { memset(piece->untyped_data(), 0, piece->size_bytes()); @@ -187,20 +190,20 @@ const SparseIndexArray* LiteralBase::sparse_indices( return piece(shape_index).sparse_indices(); } -SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { +SparseIndexArray* MutableLiteralBase::sparse_indices( + const ShapeIndex& shape_index) { return piece(shape_index).sparse_indices(); } template -Status Literal::CopySliceFromInternal( - const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +Status MutableLiteralBase::CopySliceFromInternal( + const LiteralBase& src_literal, absl::Span src_base, + absl::Span dest_base, absl::Span copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); auto linear_index = [](const Shape& shape, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); }; @@ -225,10 +228,10 @@ Status Literal::CopySliceFromInternal( // proper stride size at the matching dimension. DimensionVector src_indexes(src_base.size(), 0); DimensionVector dest_indexes(dest_base.size(), 0); - Literal::StrideConfig stride_config(src_literal.shape(), shape(), - copy_size); + MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); - auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { + auto copy_proc = [&](absl::Span indexes) { // Map from multi-dimensional index, to source index. std::transform(indexes.begin(), indexes.end(), src_base.begin(), src_indexes.begin(), std::plus()); @@ -253,9 +256,9 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index) { +Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, + absl::Span src_index, + absl::Span dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( src_literal.shape(), src_index); @@ -275,7 +278,7 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } -/* static */ StatusOr> Literal::CreateFromProto( +/* static */ StatusOr MutableLiteralBase::CreateFromProto( const LiteralProto& proto) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); @@ -284,9 +287,9 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return InvalidArgument("LiteralProto has no layout"); } - auto literal = MakeUnique(proto.shape()); + Literal literal(proto.shape()); - TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { const LiteralProto* proto_element = &proto; for (int64 i : index) { @@ -298,7 +301,7 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, if (proto_element->tuple_literals_size() != ShapeUtil::TupleElementCount(piece->subshape())) { return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", + "Expected %d tuple elements in LiteralProto, has %d", ShapeUtil::TupleElementCount(piece->subshape()), proto_element->tuple_literals_size()); } @@ -350,9 +353,9 @@ namespace { // Copies the elements in 'src' to 'dest'. The shape and layout of the data in // the array slices are indicated by dest_shape and src_shape respectively. template -void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, - tensorflow::gtl::ArraySlice src, - const Shape& dest_shape, const Shape& src_shape) { +void CopyElementsBetween(absl::Span dest, + absl::Span src, const Shape& dest_shape, + const Shape& src_shape) { CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; @@ -361,7 +364,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, do { dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; - } while (IndexUtil::BumpIndices(dest_shape, &index)); + } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index))); } } // namespace @@ -399,15 +402,15 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { default: return Unimplemented( "Copying a Literal object with element type %s is not implemented.", - PrimitiveType_Name(subshape().element_type()).c_str()); + PrimitiveType_Name(subshape().element_type())); } } return Status::OK(); } -Status Literal::CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index, - const ShapeIndex& src_shape_index) { +Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { const Shape& dest_subshape = ShapeUtil::GetSubshape(shape(), dest_shape_index); const Shape& src_subshape = @@ -415,8 +418,8 @@ Status Literal::CopyFrom(const LiteralSlice& src_literal, if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { return InvalidArgument( "Destination subshape incompatible with source subshape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_subshape).c_str()); + ShapeUtil::HumanString(dest_subshape), + ShapeUtil::HumanString(src_subshape)); } return root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -453,8 +456,8 @@ Status Literal::MoveFrom(Literal&& src_literal, if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { return InvalidArgument( "Destination subshape not equal to source shape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_literal.shape()).c_str()); + ShapeUtil::HumanString(dest_subshape), + ShapeUtil::HumanString(src_literal.shape())); } src_literal.root_piece_->ForEachSubpiece( @@ -474,7 +477,7 @@ Status Literal::MoveFrom(Literal&& src_literal, dest_piece.set_sparse_indices(src_piece.sparse_indices()); }); - src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + src_literal.shape_ = absl::make_unique(ShapeUtil::MakeNil()); delete src_literal.root_piece_; src_literal.root_piece_ = new LiteralBase::Piece(); src_literal.root_piece_->set_subshape(src_literal.shape_.get()); @@ -482,10 +485,10 @@ Status Literal::MoveFrom(Literal&& src_literal, return Status::OK(); } -Status Literal::CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size) { TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) << ShapeUtil::HumanString(src_literal.shape()); @@ -543,7 +546,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, shape().element_type()); } -void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { +void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(element_count(), values.bits()); @@ -553,40 +556,38 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { } } -std::unique_ptr LiteralBase::Relayout( - const Layout& new_layout, const ShapeIndex& shape_index) const { +Literal LiteralBase::Relayout(const Layout& new_layout, + const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. Shape new_shape = shape(); Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = MakeUnique(new_shape); - TF_CHECK_OK(result->CopyFrom(*this)); + Literal result(new_shape); + TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr LiteralBase::Relayout( - const Shape& shape_with_layout) const { +Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) << " not compatible with literal shape " << ShapeUtil::HumanString(shape()); - std::unique_ptr result = CreateFromShape(shape_with_layout); + Literal result = CreateFromShape(shape_with_layout); ShapeUtil::ForEachSubshape( - result->shape(), + result.shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) { if (ShapeUtil::IsArray(subshape)) { - TF_CHECK_OK(result->CopyFrom(*this, - /*dest_shape_index=*/index, - /*src_shape_index=*/index)); + TF_CHECK_OK(result.CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); } }); return result; } -StatusOr> LiteralBase::Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const { +StatusOr LiteralBase::Broadcast( + const Shape& result_shape, absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Broadcast only supports arrays."); } @@ -596,20 +597,20 @@ StatusOr> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr result = MakeUnique(result_shape); + Literal result(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. We put it here to avoid allocating an std::vector in // every iteration of ShapeUtil::ForEachIndex. std::vector scratch_source_index(shape().dimensions_size()); - char* dest_data = static_cast(result->untyped_data()); + char* dest_data = static_cast(result.untyped_data()); const char* source_data = static_cast(untyped_data()); const int64 primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); ShapeUtil::ForEachIndex( - result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + result_shape, [&](absl::Span output_index) { for (int64 i = 0; i < dimensions.size(); ++i) { scratch_source_index[i] = output_index[dimensions[i]]; } @@ -625,37 +626,36 @@ StatusOr> LiteralBase::Broadcast( return std::move(result); } -StatusOr> LiteralBase::Reshape( - tensorflow::gtl::ArraySlice dimensions) const { +StatusOr LiteralBase::Reshape( + absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } - std::unique_ptr output; + Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { output = Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); } else { - output = CloneToUnique(); + output = Clone(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - *output->mutable_shape_do_not_use() = + *output.mutable_shape_do_not_use() = ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); - int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + int64 elements_after = ShapeUtil::ElementsIn(output.shape()); if (elements_before != elements_after) { return InvalidArgument( "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", - ShapeUtil::HumanString(shape()).c_str(), - ShapeUtil::HumanString(output->shape()).c_str()); + ShapeUtil::HumanString(shape()), + ShapeUtil::HumanString(output.shape())); } return std::move(output); } -std::unique_ptr LiteralBase::Transpose( - tensorflow::gtl::ArraySlice permutation) const { +Literal LiteralBase::Transpose(absl::Span permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; @@ -685,33 +685,31 @@ std::unique_ptr LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = MakeUnique(permuted_shape); - DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + Literal new_literal(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes()); return new_literal; } template -std::unique_ptr LiteralBase::SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const { - auto result_literal = MakeUnique(result_shape); +Literal LiteralBase::SliceInternal( + const Shape& result_shape, absl::Span start_indices) const { + Literal result_literal(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + result_literal.EachCell( + [&](absl::Span indices, NativeT /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get(new_indices); - result_literal->Set(indices, value); + result_literal.Set(indices, value); }); return result_literal; } -std::unique_ptr LiteralBase::Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const { +Literal LiteralBase::Slice(absl::Span start_indices, + absl::Span limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; @@ -749,13 +747,7 @@ Literal LiteralBase::Clone() const { return result; } -std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = MakeUnique(shape()); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - -string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, +string LiteralBase::GetAsString(absl::Span multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); @@ -852,7 +844,7 @@ string LiteralBase::GetSparseElementAsString( } StatusOr LiteralBase::GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const { + absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -868,9 +860,8 @@ StatusOr LiteralBase::GetIntegralAsS64( case U64: return Get(multi_index); default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + return FailedPrecondition("Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type())); } } @@ -895,8 +886,8 @@ size_t LiteralBase::Hash() const { return hash_value; } -Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value) { +Status MutableLiteralBase::SetIntegralAsS64(absl::Span multi_index, + int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -918,14 +909,13 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, Set(multi_index, value); break; default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + return FailedPrecondition("Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type())); } return Status::OK(); } -tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( +absl::Span LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -933,7 +923,7 @@ tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( return p.sparse_indices()->At(sparse_element_number); } -void Literal::SortSparseElements(const ShapeIndex& shape_index) { +void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } @@ -994,7 +984,7 @@ void LiteralBase::Piece::SortSparseElementsInternal() { auto values = data(); CHECK_LE(num_elements, values.size()); sparse_indices()->SortWithValues( - tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); + absl::Span(values.data(), num_elements)); } namespace { @@ -1023,9 +1013,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, element_index.push_back(i); std::vector element_pieces; ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); } - pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); pieces->push_back("\n)"); return; } @@ -1049,8 +1039,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(": "); } else { pieces->push_back("["); - pieces->push_back( - tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); pieces->push_back("]: "); } pieces->push_back(literal.GetSparseElementAsString(i)); @@ -1061,8 +1050,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, CHECK(LayoutUtil::IsDenseArray(subshape)); - auto element_to_string = - [&](tensorflow::gtl::ArraySlice indices) -> string { + auto element_to_string = [&](absl::Span indices) -> string { PrimitiveType element_type = subshape.element_type(); if (element_type == PRED) { // We display predicates in a densely packed form. @@ -1111,9 +1099,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {\n"); for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { pieces->push_back(" {"); for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { @@ -1131,11 +1119,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {\n"); for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + pieces->push_back(StrFormat(" { /*i2=%d*/\n", i2)); for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { pieces->push_back(" {"); for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { @@ -1157,7 +1145,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {"); literal.EachCellAsString( - [&](tensorflow::gtl::ArraySlice indices, const string& value) { + [&](absl::Span indices, const string& value) { pieces->push_back(" "); pieces->push_back(value); }); @@ -1176,11 +1164,11 @@ string LiteralBase::ToString(bool print_layout) const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); - return tensorflow::str_util::Join(pieces, ""); + return absl::StrJoin(pieces, ""); } void LiteralBase::EachCellAsString( - const std::function indices, + const std::function indices, const string& value)>& per_cell) const { if (ShapeUtil::IsZeroElementArray(shape())) { return; @@ -1189,19 +1177,19 @@ void LiteralBase::EachCellAsString( shape(), /*linear_index=*/0); do { per_cell(indices, GetAsString(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); + } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); } namespace { template -std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const LiteralBase& src_literal, const ConverterType& converter) { +Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, + const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( + Literal result_literal(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); auto src_data = src_literal.data(); - auto dest_data = result_literal->template data(); + auto dest_data = result_literal.template data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { @@ -1211,8 +1199,7 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes( - const LiteralBase& src_literal) { +Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1220,7 +1207,7 @@ std::unique_ptr ConvertBetweenNativeTypes( template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); @@ -1235,22 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { // identical sizes higher up. template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { +Literal ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique( + Literal result_literal( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; - tensorflow::gtl::ArraySlice src_data = - src_literal.data(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->data(); + absl::Span src_data = src_literal.data(); + absl::Span dest_data = result_literal.data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast(src_data[i]), 0); @@ -1259,8 +1244,7 @@ std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, - bool bitcast) { +Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { return BitcastBetweenNativeTypes< @@ -1278,9 +1262,9 @@ std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, } template -StatusOr> ConvertIfDestTypeMatches( - const LiteralBase& src_literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, + PrimitiveType primitive_dest_type, + bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ case (type): \ @@ -1307,18 +1291,17 @@ StatusOr> ConvertIfDestTypeMatches( default: break; } - return Unimplemented( - "Converting from type %s to type %s is not implemented.", - PrimitiveType_Name(src_literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented("Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()), + PrimitiveType_Name(primitive_dest_type)); } -StatusOr> ConvertSwitch( - const LiteralBase& literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertSwitch(const LiteralBase& literal, + PrimitiveType primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { - return literal.CloneToUnique(); + return literal.Clone(); } switch (literal.shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ @@ -1339,47 +1322,37 @@ StatusOr> ConvertSwitch( #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. default: - return Unimplemented( - "%s from type %s to type %s is not implemented.", - (bitcast ? "Bitcast converting" : "Converting"), - PrimitiveType_Name(literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented("%s from type %s to type %s is not implemented.", + (bitcast ? "Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()), + PrimitiveType_Name(primitive_dest_type)); } } } // namespace -StatusOr> LiteralBase::Convert( +StatusOr LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> LiteralBase::BitcastConvert( +StatusOr LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { return InvalidArgument( "Cannot bitcast convert from %s to %s, bit widths are different: %d != " "%d", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str(), + PrimitiveType_Name(shape().element_type()), + PrimitiveType_Name(primitive_dest_type), primitive_util::BitWidth(shape().element_type()), primitive_util::BitWidth(primitive_dest_type)); } return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> LiteralBase::ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16) const { +StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { if (!ShapeUtil::IsTuple(dest_shape)) { - if (round_f32_to_bf16 && shape().element_type() == F32 && - dest_shape.element_type() == BF16) { - auto converter = [](float src) { - return tensorflow::bfloat16::round_to_bfloat16(src); - }; - return ConvertBetweenNativeTypesWithConverter(*this, - converter); - } return Convert(dest_shape.element_type()); } std::vector elements; @@ -1388,15 +1361,13 @@ StatusOr> LiteralBase::ConvertToShape( TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); - elements.push_back(std::move(*new_element)); + elements.push_back(std::move(new_element)); } - auto converted = MakeUnique(); - *converted = Literal::MoveIntoTuple(&elements); - return std::move(converted); + return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); } -/* static */ Literal Literal::MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements) { +/* static */ Literal MutableLiteralBase::MoveIntoTuple( + absl::Span elements) { std::vector element_shapes; for (const Literal& element : elements) { element_shapes.push_back(element.shape()); @@ -1429,6 +1400,12 @@ bool LiteralBase::Piece::EqualElementsInternal( bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + if (ShapeUtil::Equal(subshape(), other.subshape()) && + LayoutUtil::IsDenseArray(subshape())) { + CHECK_EQ(size_bytes(), other.size_bytes()); + return memcmp(buffer(), other.buffer(), size_bytes()) == 0; + } + std::vector multi_index; switch (subshape().element_type()) { case PRED: @@ -1481,7 +1458,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const { namespace { template -static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, +static bool AllElementsEqualValue(absl::Span data, NativeT value) { for (int64 i = 0; i < data.size(); ++i) { if (data[i] != value) { @@ -1680,7 +1657,62 @@ bool LiteralBase::IsAllFirst() const { }); } -bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsR1Iota() const { + if (!ShapeUtil::IsArray(shape())) { + return false; + } + + if (ShapeUtil::Rank(shape()) != 1) { + return false; + } + + auto is_iota_at_idx = [&](const int64 idx) { + switch (shape().element_type()) { + case U8: + return Get({idx}) == idx; + case U16: + return Get({idx}) == idx; + case U32: + return Get({idx}) == idx; + case U64: + return Get({idx}) == idx; + case S8: + return Get({idx}) == idx; + case S16: + return Get({idx}) == idx; + case S32: + return Get({idx}) == idx; + case S64: + return Get({idx}) == idx; + case F32: + return Get({idx}) == idx; + case F64: + return Get({idx}) == idx; + case F16: + return Get({idx}) == static_cast(idx); + case BF16: + return Get({idx}) == static_cast(idx); + case C64: + return Get({idx}) == complex64(idx, 0.0f); + case PRED: + return Get({idx}) == idx; + // token, opaque, tuple, etc. are all not iota. + default: + return false; + } + }; + + const int64 elements = ShapeUtil::ElementsIn(shape()); + for (int64 idx = 0; idx < elements; ++idx) { + if (!is_iota_at_idx(idx)) { + return false; + } + } + + return true; +} + +bool LiteralBase::IsZero(absl::Span indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1716,7 +1748,7 @@ namespace { template void CopyToRepeatedField(RepeatedFieldT* dest, - const tensorflow::gtl::ArraySlice src) { + const absl::Span src) { *dest = RepeatedFieldT(src.begin(), src.end()); } @@ -1728,6 +1760,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case PRED: CopyToRepeatedField(proto->mutable_preds(), data()); break; + case S8: + proto->set_s8s(static_cast(data().data()), + element_count()); + break; case U8: proto->set_u8s(static_cast(data().data()), element_count()); @@ -1794,7 +1830,7 @@ void* LiteralBase::Piece::untyped_data() { namespace { template -Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, +Status CopyFromRepeatedField(absl::Span dest, const RepeatedFieldT& src) { if (dest.size() != src.size()) { return InvalidArgument( @@ -1808,7 +1844,8 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, } // namespace Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { - // These conditions should have been checked in Literal::CreateFromProto. + // These conditions should have been checked in + // MutableLiteralBase::CreateFromProto. TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); @@ -1817,6 +1854,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case PRED: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); break; + case S8: { + auto s8_data = data(); + TF_RET_CHECK(proto.s8s().size() == s8_data.size()); + std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin()); + } break; case U8: { auto u8_data = data(); TF_RET_CHECK(proto.u8s().size() == u8_data.size()); @@ -1900,7 +1942,7 @@ const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } -void* Literal::untyped_data(const ShapeIndex& shape_index) { +void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } @@ -1916,6 +1958,127 @@ string LiteralBase::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } +void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, + Piece* src_piece, + Piece* dest_piece) { + DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape())) + << "src_piece has shape: " + << ShapeUtil::HumanString(src_piece->subshape()) + << "dest_piece has shape: " + << ShapeUtil::HumanString(dest_piece->subshape()); + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece); + + dest_piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + dest_piece->set_buffer(src_piece->buffer()); + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(dest_piece->size_bytes(), 0); + } +} + +MutableLiteralBase::~MutableLiteralBase() {} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableBorrowingLiteral& literal) + : MutableLiteralBase() { + shape_ = absl::make_unique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( + const MutableBorrowingLiteral& literal) { + shape_ = absl::make_unique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); + + return *this; +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableLiteralBase& literal) + : MutableLiteralBase() { + shape_ = absl::make_unique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) + : MutableLiteralBase() { + shape_ = absl::make_unique(literal->shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + MutableBorrowingLiteral literal, const ShapeIndex& view_root) + : MutableLiteralBase() { + shape_ = absl::make_unique(literal.piece(view_root).subshape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, + const Shape& shape) + : MutableLiteralBase() { + shape_ = absl::make_unique(shape); + CHECK(LayoutUtil::HasLayout(*shape_)); + CHECK(!ShapeUtil::IsTuple(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_buffer(const_cast(src_buf_ptr)); + root_piece_->set_subshape(shape_.get()); +} + +MutableBorrowingLiteral::~MutableBorrowingLiteral() { + if (root_piece_ != nullptr) { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete piece->sparse_indices(); + } + }); + delete root_piece_; + } +} + +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} + +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} + void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { CHECK(ShapeUtil::IsTuple(shape)); for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { @@ -1932,15 +2095,8 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { } } -LiteralSlice::LiteralSlice(const LiteralBase& literal) - : LiteralBase(), root_piece_(&literal.root_piece()) {} - -LiteralSlice::LiteralSlice(const LiteralBase& literal, - const ShapeIndex& view_root) - : LiteralBase(), root_piece_(&literal.piece(view_root)) {} - BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { + : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsArray(*shape_)); CHECK(LayoutUtil::HasLayout(*shape_)); @@ -1949,9 +2105,9 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) root_piece_.set_subshape(shape_.get()); } -BorrowingLiteral::BorrowingLiteral( - tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { +BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, + const Shape& shape) + : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsTuple(*shape_)); CHECK(!ShapeUtil::IsNestedTuple(*shape_)); CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index dd67dfa8d4a556aea179bc47abfdc9a9c8872c45..3cd3541fe1596600b4f0b43e3011e1f0322ac8fe 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -25,13 +25,15 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -40,8 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -70,13 +70,12 @@ class LiteralBase { // Serialize to proto. LiteralProto ToProto() const; - // Returns an ArraySlice of the array for this literal for the given NativeT + // Returns a Span of the array for this literal for the given NativeT // (e.g., float). CHECKs if the subshape of the literal at the given // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type // to native type. template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; + absl::Span data(const ShapeIndex& shape_index = {}) const; // Returns a const pointer to the sparse index array. Returns nullptr if the // literal is not a sparse array. @@ -100,12 +99,12 @@ class LiteralBase { // Gets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, + NativeT Get(absl::Span multi_index, const ShapeIndex& shape_index) const; // Overloads of Get for array literals. CHECKs if the literal is not // array-shaped and dense. template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + NativeT Get(absl::Span multi_index) const; // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -114,7 +113,7 @@ class LiteralBase { // As Get(), but determines the correct type and converts the value // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, + string GetAsString(absl::Span multi_index, const ShapeIndex& shape_index = {}) const; // As GetSparseElement(), but determines the correct type and converts the // value into text. @@ -122,14 +121,13 @@ class LiteralBase { const ShapeIndex& shape_index = {}) const; // As Get(), but determines the correct type and converts the value into // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; + StatusOr GetIntegralAsS64(absl::Span multi_index) const; // Returns the multi-index of the element in a sparse literal at the given // sparse element number. The sparse element number is the position with in // the sparse array's list of (index, value) pairs, and is checked against the // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( + absl::Span GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; // Returns the value of the element in a sparse literal at the given sparse @@ -150,12 +148,12 @@ class LiteralBase { // // This literal must have a dense layout. void EachCellAsString( - const std::function indices, + const std::function indices, const string& value)>& per_cell) const; template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; + void EachCell( + std::function indices, NativeT value)> + per_cell) const; // Returns whether every element in this literal is equal to value. // @@ -195,13 +193,20 @@ class LiteralBase { // Literal consists entirely of the first element of the literal. bool IsAllFirst() const; + // Literal consists entirely of an iota. + bool IsR1Iota() const; + // Returns whether this literal is zero at the specified index. This literal // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + bool IsZero(absl::Span indices) const; // Returns the count of the elements in the array at the given shape index in // this literal. int64 element_count(const ShapeIndex& index = {}) const { + if (index.empty()) { + // Common case, avoid GetSubshape(). + return ShapeUtil::ElementsIn(shape()); + } return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); } @@ -216,31 +221,20 @@ class LiteralBase { // Converts this literal to the given shape. Returns an error is the // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + StatusOr ConvertToShape(const Shape& dest_shape) const; // Converts this literal to another primitive type using a bitcast // conversion. The to and from primitive types must have the same bit // width. Returns an error if the conversion is not possible. This literal // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; + StatusOr BitcastConvert(PrimitiveType primitive_dest_type) const; // Converts this literal to another primitive type. Returns an error if the // conversion is not possible. This literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; + StatusOr Convert(PrimitiveType primitive_dest_type) const; - // Clones the underlying buffers into a new Literal, or new - // std::unique_ptr. + // Clones the underlying buffers into a new Literal. Literal Clone() const; - std::unique_ptr CloneToUnique() const; // TODO(b/67651157): The methods below which perform computation on Literals // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with @@ -258,25 +252,23 @@ class LiteralBase { // Note: this is useful when the client wants to ensure that a value placed in // the XLA allocation tracker has a particular layout; for efficiency // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; + Literal Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; // An overload of Relayout which changes the layout of the entire shape rather // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; + Literal Relayout(const Shape& shape_with_layout) const; // Creates a new literal by reshaping this literal to have the given // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; + StatusOr Reshape(absl::Span dimensions) const; // Creates a new literal by broadcasting this literal with `dimensions` to // yield a literal of shape `result_shape`. - StatusOr> Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const; + StatusOr Broadcast(const Shape& result_shape, + absl::Span dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -285,8 +277,7 @@ class LiteralBase { // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; + Literal Transpose(absl::Span permutation) const; // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the @@ -294,26 +285,26 @@ class LiteralBase { // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; + Literal Slice(absl::Span start_indices, + absl::Span limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. // This literal must be an array. template - std::unique_ptr Replicate(int64 times) const; + Literal Replicate(int64 times) const; // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). // // Note: It's an antipattern to use this method then immediately call - // Literal::Populate on the result (since that results in zero initialization, - // then reinitialization. Conside if a call to MakeUnique(shape), - // followed by the call to Literal::Populate can be used instead. - static std::unique_ptr CreateFromShape(const Shape& shape); + // MutableLiteralBase::Populate on the result (since that results in zero + // initialization, then reinitialization. Conside if a call to + // absl::make_unique(shape), followed by the call to + // MutableLiteralBase::Populate can be used instead. + static Literal CreateFromShape(const Shape& shape); protected: // A data structure representing a subshape at a particular ShapeIndex within @@ -324,9 +315,9 @@ class LiteralBase { // Returns the buffer holding the array data for this piece as an array // slice. This piece must be array-shaped. template - tensorflow::gtl::ArraySlice data() const; + absl::Span data() const; template - tensorflow::gtl::MutableArraySlice data(); + absl::Span data(); // Returns the buffer holding the array data for this piece as a void*. This // piece must be array-shaped. @@ -337,9 +328,9 @@ class LiteralBase { // is CHECKed against the dimension sizes of the array. This piece must be // array-shaped. template - NativeT Get(tensorflow::gtl::ArraySlice index) const; + NativeT Get(absl::Span index) const; template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); + void Set(absl::Span index, NativeT value); // Gets/sets the buffer holding the array data. char* buffer() const { return buffer_; } @@ -534,52 +525,27 @@ class LiteralBase { virtual const Piece& root_piece() const = 0; // LiteralSlice and Literal must access Pieces of other Literals. - friend class Literal; + friend class MutableLiteralBase; friend class LiteralSlice; friend class BorrowingLiteral; private: template - std::unique_ptr SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const; + Literal SliceInternal(const Shape& result_shape, + absl::Span start_indices) const; }; -// Class representing literal values in XLA. -// -// The underlying buffer and shape is always owned by this class. -class Literal : public LiteralBase { +// Abstract base class representing a mutable literal in XLA. +class MutableLiteralBase : public LiteralBase { public: - Literal() : Literal(ShapeUtil::MakeNil()) {} - - // Create a literal of the given shape. The literal is allocated sufficient - // memory to hold the shape. Memory is uninitialized. - explicit Literal(const Shape& shape); - virtual ~Literal(); - - // Literals are moveable, but not copyable. To copy a literal use - // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies - // of literals which can be expensive. - Literal(const Literal& other) = delete; - Literal& operator=(const Literal& other) = delete; - Literal(Literal&& other); - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); - Literal& operator=(Literal&& other); - - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return shape_.get(); } + virtual ~MutableLiteralBase() = 0; - // Returns a MutableArraySlice view of the array for this literal for the + // Returns a Span view of the array for this literal for the // given NativeT (e.g., float). CHECKs if the subshape of the literal at the // given ShapeIndex is not array. See primitive_util.h for the mapping from // XLA type to native type. template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); + absl::Span data(const ShapeIndex& shape_index = {}); // Unhide const method from parent class. using LiteralBase::data; @@ -587,6 +553,10 @@ class Literal : public LiteralBase { // is not a sparse array. SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } + // Returns a pointer to the underlying buffer holding the array at the given // shape index. CHECKs if the subshape of the literal at the given ShapeIndex // is not array. @@ -602,8 +572,7 @@ class Literal : public LiteralBase { // are populated. template void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); + absl::Span values, bool sort = true); // Copy values from 'src_literal' rooted at 'src_shape_index' into this // literal rooted at 'dest_shape_index'. The subshape of this literal rooted @@ -613,21 +582,6 @@ class Literal : public LiteralBase { const ShapeIndex& dest_shape_index = {}, const ShapeIndex& src_shape_index = {}); - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - // Copies the values from src_literal, starting at src_base shape indexes, // to this literal, starting at dest_base, where the copy size in each // dimension is specified by copy_size. @@ -639,39 +593,38 @@ class Literal : public LiteralBase { // corresponding base indices being 0. // This literal and 'src_literal' must be arrays. Status CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size); // Copies one element from src_literal[src_index] to (*this)[dest_index]. Status CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); + absl::Span src_index, + absl::Span dest_index); // Sets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); + void Set(absl::Span multi_index, const ShapeIndex& shape_index, + NativeT value); // Overloads of Set for array literals. CHECKs if the literal is not // array-shaped and dense. template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + void Set(absl::Span multi_index, NativeT value); // Appends the given element to the literal. If the elements are not appended // in sorted order, then SortSparseElements should be called before calling // other methods. This literal must have a sparse layout. template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); + void AppendSparseElement(absl::Span multi_index, NativeT value, + const ShapeIndex& shape_index = {}); // Sorts the elements in a sparse array. void SortSparseElements(const ShapeIndex& shape_index = {}); // As Set(), but truncates `value` to the literal element type before storing. // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); + Status SetIntegralAsS64(absl::Span multi_index, int64 value); // Populate this literal with the given values. Examples: // @@ -686,7 +639,7 @@ class Literal : public LiteralBase { // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 // array of S32. template - void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(absl::Span values); void PopulateR1(const tensorflow::core::Bitmap& values); template void PopulateR2(std::initializer_list> values); @@ -703,7 +656,7 @@ class Literal : public LiteralBase { // in this literal object. // // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // NativeT(absl::Span indexes) or compatible. // // This literal must have a dense layout. template @@ -723,19 +676,12 @@ class Literal : public LiteralBase { // moved into the tuple elements of a new tuple-shaped Literal which is // returned. Upon return, each of the Literals in 'elements' is set to a nil // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); + static Literal MoveIntoTuple(absl::Span elements); // Serialize from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); - - private: - // Recursively sets the subshapes and buffers of all subpieces rooted at - // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in - // the shape. - void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); + static StatusOr CreateFromProto(const LiteralProto& proto); + protected: // Returns the piece at the given ShapeIndex. Piece& piece(const ShapeIndex& shape_index) { return const_cast(LiteralBase::piece(shape_index)); @@ -747,20 +693,20 @@ class Literal : public LiteralBase { // arguments one by one. template Status CopySliceFromInternal(const LiteralBase& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size); // Utility structure which is used to create the optimal configuration for // a ShapeUtil::ForEachIndex() scan across two literals. struct StrideConfig { StrideConfig(const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // The dimensions of the stride operation. Essentially every dimension // will be iterated from base[i] to base[i]+dimensions[i], in step[i] // steps. - tensorflow::gtl::ArraySlice dimensions; + absl::Span dimensions; DimensionVector base; DimensionVector step; int64 minor_dimension = 0; @@ -783,12 +729,83 @@ class Literal : public LiteralBase { template Status PopulateInternal(const FnType& generator, bool parallel); + friend class LiteralBase; + friend class MutableBorrowingLiteral; +}; +std::ostream& operator<<(std::ostream& out, const Literal& literal); + +// The underlying buffer and shape is always owned by this class. +class Literal : public MutableLiteralBase { + public: + Literal() : Literal(ShapeUtil::MakeNil()) {} + + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); + + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); + Literal& operator=(Literal&& other); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + virtual Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); + + private: // Deallocate the buffers held by this literal. void DeallocateBuffers(); - friend class LiteralBase; + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); +}; + +// The underlying buffer is not owned by this class and is always owned by +// others. The shape is not owned by this class and not mutable. +class MutableBorrowingLiteral : public MutableLiteralBase { + public: + virtual ~MutableBorrowingLiteral(); + + MutableBorrowingLiteral() : MutableLiteralBase() {} + + MutableBorrowingLiteral(const MutableBorrowingLiteral& literal); + MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal); + + // Implicit conversion constructors. + MutableBorrowingLiteral(const MutableLiteralBase& literal); + MutableBorrowingLiteral(MutableLiteralBase* literal); + MutableBorrowingLiteral(MutableBorrowingLiteral literal, + const ShapeIndex& view_root); + MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + + private: + // Recursively copies the subtree from the `src_piece` at the given child + // index to the `dest_piece`. For buffers only the pointers are copied, but + // not the content. + void CopyPieceSubtree(const Shape& shape, Piece* src_piece, + Piece* dest_piece); }; -std::ostream& operator<<(std::ostream& out, const Literal& literal); // A read-only view of a Literal. A LiteralSlice contains pointers to shape and // literal buffers always owned by others. @@ -818,7 +835,7 @@ class BorrowingLiteral : public LiteralBase { // This constructor is only used for array shapes. BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); // Similar as above, except to be used for constructing non-nested tuples. - BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + BorrowingLiteral(absl::Span src_buf_ptrs, const Shape& shape); // TODO(b/79707221): adding constructors for nested tuples as well. @@ -831,48 +848,47 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. Stored as unique_ptr so such that the (default) - // move construction of this class would be trivially correct: the pointer to - // Shape root_piece_ stores will still point to the correct address. + // Shape of this literal. Stored as unique_ptr such that the (default) move + // construction of this class would be trivially correct: the pointer to Shape + // root_piece_ stores will still point to the correct address. std::unique_ptr shape_; }; template -tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) +absl::Span LiteralBase::Piece::data() const { + DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) << "Attempting to access " << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(buffer()), element_count()); + return absl::Span(reinterpret_cast(buffer()), + element_count()); } template -tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) +absl::Span LiteralBase::Piece::data() { + DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) << "Attempting to access " << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(buffer()), element_count()); + return absl::Span(reinterpret_cast(buffer()), + element_count()); } template -NativeT LiteralBase::Piece::Get( - tensorflow::gtl::ArraySlice multi_index) const { +NativeT LiteralBase::Piece::Get(absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(subshape())); return data()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)]; } template -void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, +void LiteralBase::Piece::Set(absl::Span multi_index, NativeT value) { CHECK(LayoutUtil::IsDenseArray(subshape())); data()[IndexUtil::MultidimensionalIndexToLinearIndex( @@ -880,38 +896,37 @@ void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, } template -tensorflow::gtl::ArraySlice LiteralBase::data( +absl::Span LiteralBase::data( const ShapeIndex& shape_index) const { return piece(shape_index).data(); } template -tensorflow::gtl::MutableArraySlice Literal::data( - const ShapeIndex& shape_index) { +absl::Span MutableLiteralBase::data(const ShapeIndex& shape_index) { return piece(shape_index).data(); } template -inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, +inline NativeT LiteralBase::Get(absl::Span multi_index, const ShapeIndex& shape_index) const { return piece(shape_index).Get(multi_index); } template -inline NativeT LiteralBase::Get( - tensorflow::gtl::ArraySlice multi_index) const { +inline NativeT LiteralBase::Get(absl::Span multi_index) const { return root_piece().Get(multi_index); } template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value) { +inline void MutableLiteralBase::Set(absl::Span multi_index, + const ShapeIndex& shape_index, + NativeT value) { return piece(shape_index).Set(multi_index, value); } template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { +inline void MutableLiteralBase::Set(absl::Span multi_index, + NativeT value) { return root_piece().Set(multi_index, value); } @@ -929,8 +944,8 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, } template -void Literal::AppendSparseElement( - tensorflow::gtl::ArraySlice multi_index, NativeT value, +void MutableLiteralBase::AppendSparseElement( + absl::Span multi_index, NativeT value, const ShapeIndex& shape_index) { Piece& p = piece(shape_index); const Shape& subshape = p.subshape(); @@ -946,8 +961,7 @@ void Literal::AppendSparseElement( template void LiteralBase::EachCell( - std::function indices, - NativeT value)> + std::function indices, NativeT value)> per_cell) const { if (ShapeUtil::IsZeroElementArray(shape())) { return; @@ -955,11 +969,11 @@ void LiteralBase::EachCell( std::vector indices(ShapeUtil::Rank(shape()), 0); do { per_cell(indices, Get(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); + } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); } template -inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { +inline void MutableLiteralBase::PopulateR1(absl::Span values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); @@ -971,7 +985,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { } template -void Literal::PopulateR2( +void MutableLiteralBase::PopulateR2( std::initializer_list> values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 2); @@ -996,7 +1010,7 @@ void Literal::PopulateR2( } template -void Literal::PopulateFromArray(const Array& values) { +void MutableLiteralBase::PopulateFromArray(const Array& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -1004,29 +1018,30 @@ void Literal::PopulateFromArray(const Array& values) { for (int dim = 0; dim < values.num_dimensions(); ++dim) { CHECK_EQ(values.dim(dim), shape().dimensions(dim)); } - values.Each([this](tensorflow::gtl::ArraySlice indices, - NativeT value) { this->Set(indices, value); }); + values.Each([this](absl::Span indices, NativeT value) { + this->Set(indices, value); + }); } template -void Literal::PopulateR2FromArray2D(const Array2D& values) { +void MutableLiteralBase::PopulateR2FromArray2D(const Array2D& values) { PopulateFromArray(values); } template -void Literal::PopulateR3FromArray3D(const Array3D& values) { +void MutableLiteralBase::PopulateR3FromArray3D(const Array3D& values) { PopulateFromArray(values); } template -void Literal::PopulateR4FromArray4D(const Array4D& values) { +void MutableLiteralBase::PopulateR4FromArray4D(const Array4D& values) { PopulateFromArray(values); } template -void Literal::PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort) { +void MutableLiteralBase::PopulateSparse(SparseIndexArray indices, + absl::Span values, + bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); int rank = ShapeUtil::Rank(shape()); CHECK_EQ(indices.rank(), rank); @@ -1036,7 +1051,7 @@ void Literal::PopulateSparse(SparseIndexArray indices, CHECK_LE(num_elements, max_elements); CHECK_EQ(num_elements, indices.index_count()); auto root_data = root_piece().data(); - // Piece::data() returns an ArraySlice of size equal to the number of indices + // Piece::data() returns a Span of size equal to the number of indices // in the SparseIndexArray. So there is no need to adjust the size of the data // here. It is enough to just copy the incoming values into the data buffer. std::copy(values.begin(), values.end(), root_data.begin()); @@ -1049,20 +1064,21 @@ void Literal::PopulateSparse(SparseIndexArray indices, } template -Status Literal::PopulateInternal(const FnType& generator, bool parallel) { +Status MutableLiteralBase::PopulateInternal(const FnType& generator, + bool parallel) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); - tensorflow::gtl::MutableArraySlice literal_data = data(); + absl::Span literal_data = data(); if (rank > 0) { StrideConfig stride_config(this_shape, this_shape, AsInt64Slice(this_shape.dimensions())); int64 minor_dimension_size = ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { + auto init_function = [&](absl::Span indexes) { DimensionVector minor_scan_indexes(rank, 0); const int64 index = IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); @@ -1080,7 +1096,7 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) { ShapeUtil::ForEachIndex( this_shape, stride_config.base, stride_config.dimensions, stride_config.step, - [&init_function](tensorflow::gtl::ArraySlice indexes) { + [&init_function](absl::Span indexes) { init_function(indexes); return true; }); @@ -1092,17 +1108,17 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) { return Status::OK(); } template -Status Literal::Populate(const FnType& generator) { +Status MutableLiteralBase::Populate(const FnType& generator) { return PopulateInternal(generator, /*parallel=*/false); } template -Status Literal::PopulateParallel(const FnType& generator) { +Status MutableLiteralBase::PopulateParallel(const FnType& generator) { return PopulateInternal(generator, /*parallel=*/true); } template -void Literal::PopulateWithValue(NativeT value) { +void MutableLiteralBase::PopulateWithValue(NativeT value) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -1112,27 +1128,26 @@ void Literal::PopulateWithValue(NativeT value) { } template -std::unique_ptr LiteralBase::Replicate(int64 times) const { +Literal LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = - MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); - int64 elements = ShapeUtil::ElementsIn(literal->shape()); + Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds)); + int64 elements = ShapeUtil::ElementsIn(literal.shape()); if (elements == 0) { return literal; } DimensionVector output_indices(bounds.size(), 0); - tensorflow::gtl::ArraySlice input_indices = output_indices; + absl::Span input_indices = output_indices; input_indices.remove_prefix(1); bool done = false; while (!done) { const auto element = Get(input_indices); - literal->Set(output_indices, element); + literal.Set(output_indices, element); done = true; for (int n = 0; n < output_indices.size(); ++n) { diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 94993cc87443ba8c22fd7c2eacfc8756d3f48edc..3d8725ed7051cafc97987f25a96004fa876dfdd3 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" -using tensorflow::strings::Appendf; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrAppendFormat; +using absl::StrCat; namespace xla { namespace literal_comparison { @@ -38,7 +38,8 @@ namespace { // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. template -Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span multi_index) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); auto lhs_double = static_cast(lhs); @@ -46,9 +47,10 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a", - StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, - StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double); + "was requested: %s=%g=%a vs %s=%g=%a at array index %s", + StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, + StrCat(absl::Hex(urhs)), rhs_double, rhs_double, + LiteralUtil::MultiIndexAsString(multi_index)); } return Status::OK(); } @@ -57,39 +59,47 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { // bitwise helper above (this is the un-specialized fallback, to just use the // default gunit implementation). template -Status CompareEqual(NativeT lhs, NativeT rhs) { +Status CompareEqual(NativeT lhs, NativeT rhs, + absl::Span multi_index) { if (lhs == rhs) { return Status::OK(); } - return InvalidArgument("Expected equality of these values:\n %s\n %s", - StrCat(lhs).c_str(), StrCat(rhs).c_str()); + return InvalidArgument( + "first mismatch at array index %s:\n expected value: %s\n actual " + "value: %s", + LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. template <> -Status CompareEqual(bfloat16 lhs, bfloat16 rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(Eigen::half lhs, Eigen::half rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(float lhs, float rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(float lhs, float rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(double lhs, double rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(double lhs, double rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(complex64 lhs, complex64 rhs) { - auto res = CompareEqual(lhs.real(), rhs.real()); +Status CompareEqual(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); if (!res.ok()) { return res; } - return CompareEqual(lhs.imag(), rhs.imag()); + return CompareEqual(lhs.imag(), rhs.imag(), multi_index); } // A recursive function which iterates through every index of expected and @@ -97,18 +107,18 @@ Status CompareEqual(complex64 lhs, complex64 rhs) { // elements are equal. template Status Equal(LiteralSlice expected, LiteralSlice actual, - tensorflow::gtl::MutableArraySlice multi_index, - int64 dimension) { + absl::Span multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); - return CompareEqual(expected_value, actual_value); + return CompareEqual(expected_value, actual_value, multi_index); } Status result; for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index[dimension] = i; - result.Update(Equal(expected, actual, multi_index, dimension + 1)); + TF_RETURN_IF_ERROR( + Equal(expected, actual, multi_index, dimension + 1)); } return result; } @@ -152,15 +162,26 @@ bool NanMismatch(half expected, half actual, bool relaxed_nans) { static_cast(actual), relaxed_nans); } +// Returns whether the given value is infinity. +template +bool IsInf(NativeT val) { + return std::isinf(val); +} + +template <> +bool IsInf(half val) { + return std::isinf(static_cast(val)); +} + // Converts the given floating-point value to a string. template string FpValueToString(NativeT value) { - return Printf("%8.4g", static_cast(value)); + return absl::StrFormat("%8.4g", static_cast(value)); } template <> string FpValueToString(complex64 value) { - return Printf("%8.4g + %8.4fi", value.real(), value.imag()); + return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); } // Returns the absolute value of the given floating point value. This function @@ -215,13 +236,12 @@ class NearComparator { } string ToString(const Shape& shape) const { - return Printf( + return absl::StrFormat( "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", - FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), + FpValueToString(actual), FpValueToString(expected), LiteralUtil::MultiIndexAsString( IndexUtil::LinearIndexToMultidimensionalIndex(shape, - linear_index)) - .c_str(), + linear_index)), rel_error, abs_error); } }; @@ -240,17 +260,12 @@ class NearComparator { // Runs the comparison between expected and actual literals. Status Run() { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, ToStringTruncated(expected_)); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, ToStringTruncated(actual_)); - // If the shapes mismatch, we simply fail the expectation instead of // printing out data, as it's a type error rather than a value error. TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); if (!ShapeUtil::IsArray(expected_.shape())) { return InvalidArgument("Expected array shape; got %s.", - ShapeUtil::HumanString(expected_.shape()).c_str()); + ShapeUtil::HumanString(expected_.shape())); } mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); @@ -263,7 +278,7 @@ class NearComparator { } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { miscompare_callback_(expected_, actual_, mismatches_); } - return InvalidArgument("%s", ErrorMessage().c_str()); + return InvalidArgument("%s", ErrorMessage()); } // Insert the given absolute value into the absolute value bucket vector. The @@ -288,8 +303,7 @@ class NearComparator { } // Insert the given error into the given error bucket vector. - void UpdateErrorBucket( - float error, tensorflow::gtl::MutableArraySlice error_buckets) { + void UpdateErrorBucket(float error, absl::Span error_buckets) { CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); for (int i = 0; i < error_buckets.size(); ++i) { if (error >= kErrorBucketBounds[i]) { @@ -300,12 +314,13 @@ class NearComparator { // Compares the two given elements from the expected and actual literals at // the given literal_index and keeps track of various mismatch statistics. - void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { + template + void CompareValues(T expected, T actual, int64 linear_index) { const bool is_nan_mismatch = NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; - if (actual == expected) { + if (CompareEqual(expected, actual, {linear_index}).ok()) { abs_error = 0; rel_error = 0; } else if (is_nan_mismatch) { @@ -316,6 +331,12 @@ class NearComparator { // weak ordering requirement of std containers. abs_error = std::numeric_limits::infinity(); rel_error = std::numeric_limits::infinity(); + } else if (IsInf(expected) || IsInf(actual)) { + // If either the expected or actual value is infinity but not both, + // then both absolute and relative error are regarded as inifity. + CHECK(!CompareEqual(expected, actual, {linear_index}).ok()); + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); } else { abs_error = FpAbsoluteValue(actual - expected); rel_error = abs_error / FpAbsoluteValue(expected); @@ -329,11 +350,11 @@ class NearComparator { // bound is exceeded and vice versa. if (is_abs_mismatch) { num_abs_mismatches_++; - UpdateErrorBucket(rel_error, &rel_error_buckets_); + UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_)); } if (is_rel_mismatch) { num_rel_mismatches_++; - UpdateErrorBucket(abs_error, &abs_error_buckets_); + UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_)); } UpdateAbsValueBucket(actual, is_mismatch); @@ -358,15 +379,36 @@ class NearComparator { mismatches_.data()[linear_index] = true; } + // For complex64 types, we compare real and imaginary parts individually. + void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { + bool mismatch = false; + CompareValues(expected.real(), actual.real(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for real part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + CompareValues(expected.imag(), actual.imag(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for imag part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + if (mismatch == true) { + num_mismatches_++; + } + mismatches_.data()[linear_index] = mismatch; + } + // Compares the two literals elementwise. void CompareLiterals() { // Fast path optimization for the case were layouts match. if (LayoutUtil::Equal(actual_.shape().layout(), expected_.shape().layout())) { - tensorflow::gtl::ArraySlice expected_data = - expected_.data(); - tensorflow::gtl::ArraySlice actual_data = - actual_.data(); + absl::Span expected_data = expected_.data(); + absl::Span actual_data = actual_.data(); const int64 len = expected_data.size(); for (int64 i = 0; i < len; ++i) { CompareValues(expected_data[i], actual_data[i], i); @@ -402,23 +444,23 @@ class NearComparator { auto percent_string = [](float a, float b) { float pct = b == 0.0 ? 0.0 : 100.0 * a / b; - return Printf("%0.4f%%", pct); + return absl::StrFormat("%0.4f%%", pct); }; - Appendf(&out, - "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " - "%g, rel bound %g\n", - num_mismatches_, - percent_string(num_mismatches_, element_count).c_str(), - ShapeUtil::HumanString(actual_.shape()).c_str(), - ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); + StrAppendFormat( + &out, + "\nMismatch count %d (%s) in shape %s (%d elements), abs bound " + "%g, rel bound %g\n", + num_mismatches_, percent_string(num_mismatches_, element_count), + ShapeUtil::HumanString(actual_.shape()), + ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); if (num_nan_mismatches_ > 0) { StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); } - Appendf(&out, "Top relative error mismatches:\n"); + StrAppendFormat(&out, "Top relative error mismatches:\n"); for (auto it = top_rel_mismatches_.rbegin(); it != top_rel_mismatches_.rend(); ++it) { - StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); + StrAppend(&out, " ", it->ToString(actual_.shape()), "\n"); } if (!detailed_message_) { @@ -430,36 +472,37 @@ class NearComparator { for (int i = 0; i < abs_value_buckets_.size(); ++i) { const int64 bucket_size = abs_value_buckets_[i].first; const int64 bucket_mismatches = abs_value_buckets_[i].second; - string mismatch_str = bucket_mismatches > 0 - ? Printf(", mismatches %lld", bucket_mismatches) - : ""; - Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", - kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], - bucket_size, percent_string(bucket_size, element_count).c_str(), - mismatch_str.c_str()); + string mismatch_str = + bucket_mismatches > 0 + ? absl::StrFormat(", mismatches %d", bucket_mismatches) + : ""; + StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n", + kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], + bucket_size, percent_string(bucket_size, element_count), + mismatch_str); } auto print_accum_buckets = [&](const string& header, int64 total, - tensorflow::gtl::ArraySlice buckets) { + absl::Span buckets) { StrAppend(&out, header, ":\n"); - Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], - total - buckets[0], - percent_string(total - buckets[0], total).c_str()); + StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0], + total - buckets[0], + percent_string(total - buckets[0], total)); CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); for (int i = 0; i < kErrorBucketBounds.size(); ++i) { - Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], - buckets[i], percent_string(buckets[i], total).c_str()); + StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i], + buckets[i], percent_string(buckets[i], total)); } }; - Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", - error_.abs, num_abs_mismatches_, - percent_string(num_abs_mismatches_, element_count).c_str()); + StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n", + error_.abs, num_abs_mismatches_, + percent_string(num_abs_mismatches_, element_count)); print_accum_buckets( "Relative error breakdown of elements exceeding abs error bound", num_abs_mismatches_, rel_error_buckets_); - Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", - error_.rel, num_rel_mismatches_, - percent_string(num_rel_mismatches_, element_count).c_str()); + StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n", + error_.rel, num_rel_mismatches_, + percent_string(num_rel_mismatches_, element_count)); print_accum_buckets( "Absolute error breakdown of elements exceeding rel error bound", num_rel_mismatches_, abs_error_buckets_); @@ -528,6 +571,63 @@ constexpr std::array NearComparator::kAbsValueBucketBounds; template constexpr std::array NearComparator::kErrorBucketBounds; +Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + std::vector multi_index(expected.shape().dimensions_size(), 0); + auto index = absl::MakeSpan(multi_index); + Status result; + switch (expected.shape().element_type()) { + case PRED: + result = Equal(expected, actual, index, 0); + break; + case U8: + result = Equal(expected, actual, index, 0); + break; + case S32: + result = Equal(expected, actual, index, 0); + break; + case S64: + result = Equal(expected, actual, index, 0); + break; + case U32: + result = Equal(expected, actual, index, 0); + break; + case U64: + result = Equal(expected, actual, index, 0); + break; + case BF16: + result = Equal(expected, actual, index, 0); + break; + case F16: + result = Equal(expected, actual, index, 0); + break; + case F32: + result = Equal(expected, actual, index, 0); + break; + case F64: + result = Equal(expected, actual, index, 0); + break; + case C64: + result = Equal(expected, actual, index, 0); + break; + case TUPLE: { + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + result.Update(EqualHelper(LiteralSlice(expected, {i}), + LiteralSlice(actual, {i}))); + } + break; + } + case TOKEN: + // Tokens have no on-device representation and are trivially equal. + return Status::OK(); + default: + LOG(FATAL) << "Unsupported primitive type: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + return result; +} + // Helper function for comparing two literals for nearness. Handles tuple-shapes // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. @@ -544,17 +644,18 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, const auto actual_element = LiteralSlice(actual, {i}); ShapeIndex element_index = shape_index; element_index.push_back(i); - Status res = + Status element_result = NearHelper(expected_element, actual_element, error, detailed_message, miscompare_callback, element_index); - if (!res.ok()) { - string err_message = Printf("\nArray at shape index %s%s", - element_index.ToString().c_str(), - res.error_message().c_str()); + if (!element_result.ok()) { + element_result = InvalidArgument("Array at shape index %s, %s", + element_index.ToString(), + element_result.error_message()); if (return_status.ok()) { - return_status = res; + return_status = element_result; } else { - return_status = AppendStatus(return_status, res.error_message()); + return_status = + AppendStatus(return_status, element_result.error_message()); } } } @@ -562,10 +663,10 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, // Emit a top-level error message containing the top-level shape in case // of mismatch. int64 total_elements = RecursiveElementCount(actual.shape()); - return_status = InvalidArgument( - "\nMismatches in shape %s (%lld elements):\n%s", - ShapeUtil::HumanString(actual.shape()).c_str(), total_elements, - return_status.error_message().c_str()); + return_status = + InvalidArgument("\nMismatches in shape %s (%d elements):\n%s", + ShapeUtil::HumanString(actual.shape()), + total_elements, return_status.error_message()); } return return_status; } @@ -600,8 +701,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, } } - // Non-floating point literal. - return literal_comparison::Equal(expected, actual); + // Non-floating point, non-tuple literal. + return EqualHelper(expected, actual); } } // namespace @@ -609,14 +710,14 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, Status EqualShapes(const Shape& expected, const Shape& actual) { if (expected.element_type() != actual.element_type()) { return InvalidArgument("element type mismatch, want: %s got %s", - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), + ShapeUtil::HumanString(actual)); } if (ShapeUtil::IsTuple(expected)) { if (ShapeUtil::TupleElementCount(expected) != ShapeUtil::TupleElementCount(actual)) { return InvalidArgument( - "want tuple element count: %lld got tuple element count: %lld", + "want tuple element count: %d got tuple element count: %d", ShapeUtil::TupleElementCount(expected), ShapeUtil::TupleElementCount(actual)); } @@ -630,14 +731,13 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { } else if (ShapeUtil::IsArray(expected)) { if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { return InvalidArgument("want rank of %s got rank of %s", - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), + ShapeUtil::HumanString(actual)); } if (expected.element_type() != actual.element_type()) { - return InvalidArgument( - "mismatch in primitive type %s vs %s", - PrimitiveType_Name(expected.element_type()).c_str(), - PrimitiveType_Name(actual.element_type()).c_str()); + return InvalidArgument("mismatch in primitive type %s vs %s", + PrimitiveType_Name(expected.element_type()), + PrimitiveType_Name(actual.element_type())); } if (expected.dimensions_size() != actual.dimensions_size()) { return InvalidArgument("want dimensions_size %d got dimensions_size %d", @@ -648,8 +748,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { if (expected.dimensions(i) != actual.dimensions(i)) { return InvalidArgument( "mismatch in dimension #%d expected: %s actual: %s", i, - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); } } } @@ -657,83 +756,43 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return Status::OK(); } +namespace { + +// If result is an error, extend the error message with the expected and actual +// literals. +Status EmitLiteralsInErrorMessage(const Status& result, + const LiteralSlice& expected, + const LiteralSlice& actual) { + if (result.ok()) { + return result; + } + return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s", + result.error_message(), ToStringTruncated(expected), + ToStringTruncated(actual)); +} + +} // namespace + Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { VLOG(1) << "expected:"; XLA_VLOG_LINES(1, expected.ToString()); VLOG(1) << "actual:"; XLA_VLOG_LINES(1, actual.ToString()); - - TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); - std::vector multi_index(expected.shape().dimensions_size(), 0); - Status result; - switch (expected.shape().element_type()) { - case PRED: - result = Equal(expected, actual, &multi_index, 0); - break; - case U8: - result = Equal(expected, actual, &multi_index, 0); - break; - case S32: - result = Equal(expected, actual, &multi_index, 0); - break; - case S64: - result = Equal(expected, actual, &multi_index, 0); - break; - case U32: - result = Equal(expected, actual, &multi_index, 0); - break; - case U64: - result = Equal(expected, actual, &multi_index, 0); - break; - case BF16: - result = Equal(expected, actual, &multi_index, 0); - break; - case F16: - result = Equal(expected, actual, &multi_index, 0); - break; - case F32: - result = Equal(expected, actual, &multi_index, 0); - break; - case F64: - result = Equal(expected, actual, &multi_index, 0); - break; - case C64: - result = Equal(expected, actual, &multi_index, 0); - break; - case TUPLE: { - for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - result.Update( - Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}))); - } - break; - } - case TOKEN: - // Tokens have no on-device representation and are trivially equal. - return Status::OK(); - default: - LOG(FATAL) - << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " - << PrimitiveType_Name(expected.shape().element_type()); - } - - if (result.ok()) { - return Status::OK(); - } - - return AppendStatus(result, - tensorflow::strings::Printf( - "\nat index: %s\nexpected: %s\nactual: %s", - LiteralUtil::MultiIndexAsString(multi_index).c_str(), - ToStringTruncated(expected).c_str(), - ToStringTruncated(actual).c_str())); + Status result = EqualHelper(expected, actual); + return EmitLiteralsInErrorMessage(result, expected, actual); } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error, bool detailed_message, const MiscompareCallback& miscompare_callback) { - return NearHelper(expected, actual, error, detailed_message, - miscompare_callback, - /*shape_index=*/{}); + VLOG(1) << "Expected literal:"; + XLA_VLOG_LINES(1, expected.ToString()); + VLOG(1) << "Actual literal:"; + XLA_VLOG_LINES(1, actual.ToString()); + Status result = + NearHelper(expected, actual, error, detailed_message, miscompare_callback, + /*shape_index=*/{}); + return EmitLiteralsInErrorMessage(result, expected, actual); } string ToStringTruncated(const LiteralSlice& literal) { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index e8f919950f0efc8b508f7ad4aee5233176bc0abd..7ad287c8973367fb04583e6911ff75e76bdf5f1e 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -17,6 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -33,7 +36,6 @@ limitations under the License. namespace xla { namespace { -using tensorflow::gtl::ArraySlice; using ::testing::ElementsAre; using ::testing::HasSubstr; @@ -90,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test { Layout layout_r3_dim0minor_; Layout layout_r4_dim0major_; Layout layout_r4_dim0minor_; - std::unique_ptr literal_r4_2x2x3x3_dim0major_; - std::unique_ptr literal_r4_2x2x3x3_dim0minor_; + Literal literal_r4_2x2x3x3_dim0major_; + Literal literal_r4_2x2x3x3_dim0minor_; }; TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - ASSERT_EQ("true", true_lit->ToString()); + EXPECT_EQ("true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - ASSERT_EQ("false", false_lit->ToString()); + EXPECT_EQ("false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - ASSERT_EQ("42", u32_lit->ToString()); + EXPECT_EQ("42", u32_lit.ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - ASSERT_EQ("-999", s32_lit->ToString()); + EXPECT_EQ("-999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - ASSERT_EQ("3.14", f32_lit->ToString()); + EXPECT_EQ("3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", f16_lit->ToString()); + EXPECT_EQ("0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); + EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", bf16_lit->ToString()); + EXPECT_EQ("0.5", bf16_lit.ToString()); - // 3.14 will be truncated to 3.125 in bfloat16 format. + // 3.14 will be rounded to 3.14062 in bfloat16 format. auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); + ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - ASSERT_EQ("9", bf16_lit_truncated2->ToString()); + EXPECT_EQ("9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - ASSERT_EQ("{101}", pred_vec->ToString()); + EXPECT_EQ("{101}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -141,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) { { 3, 4 }, { 5, 6 } })"; - ASSERT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, R3ToString) { @@ -155,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - ASSERT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -169,7 +171,7 @@ f32[2,2] { { 3, 4 } } ))"; - ASSERT_EQ(expected, tuple->ToString()); + EXPECT_EQ(expected, tuple.ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -185,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { // clang-format on auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[2,3,2] { { { 1, 2 }, { 3, 4 }, @@ -195,7 +197,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { { 9, 10 }, { 11, 12 } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, CreateSparse) { @@ -218,10 +220,10 @@ TEST_F(LiteralUtilTest, CreateSparse) { }; std::vector expected_values = {8, 9, 7, 10}; - EXPECT_EQ(literal->sparse_indices()->data(), - ArraySlice(expected_indices.data(), - expected_indices.num_elements())); - EXPECT_EQ(literal->data(), ArraySlice(expected_values)); + EXPECT_EQ(literal.sparse_indices()->data(), + absl::Span(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal.data(), absl::Span(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -232,8 +234,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[1,2,3,2] { { /*i0=0*/ { /*i1=0*/ @@ -248,13 +250,13 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { - EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), + EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(), ElementsAre(2, 2, 3, 3)); - string result = literal_r4_2x2x3x3_dim0major_->ToString(); + string result = literal_r4_2x2x3x3_dim0major_.ToString(); const string expected = R"(f32[2,2,3,3] { { /*i0=0*/ { /*i1=0*/ @@ -281,7 +283,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, EachCellR2F32) { @@ -292,8 +294,8 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { }); // clang-format on std::vector> seen; - literal->EachCellAsString( - [&seen](ArraySlice indices, const string& value) { + literal.EachCellAsString( + [&seen](absl::Span indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -308,14 +310,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) { auto f32_42 = LiteralUtil::CreateR0(42.0); auto f32_42_clone = LiteralUtil::CreateR0(42.0); - EXPECT_EQ(*f32_42, *f32_42); - EXPECT_EQ(*f32_42, *f32_42_clone); + EXPECT_EQ(f32_42, f32_42); + EXPECT_EQ(f32_42, f32_42_clone); auto f32_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*f32_42, *f32_123); + EXPECT_NE(f32_42, f32_123); auto f64_42 = LiteralUtil::CreateR0(42.0); - EXPECT_NE(*f32_42, *f64_42); + EXPECT_NE(f32_42, f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { @@ -328,12 +330,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { auto scalar = LiteralUtil::CreateR0(1.0); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(*matrix, *matrix); - EXPECT_EQ(*matrix, *matrix_clone); - EXPECT_NE(*matrix, *matrix_different); - EXPECT_NE(*matrix, *vector_literal); - EXPECT_NE(*matrix, *scalar); - EXPECT_NE(*matrix, nil); + EXPECT_EQ(matrix, matrix); + EXPECT_EQ(matrix, matrix_clone); + EXPECT_NE(matrix, matrix_different); + EXPECT_NE(matrix, vector_literal); + EXPECT_NE(matrix, scalar); + EXPECT_NE(matrix, nil); EXPECT_EQ(nil, nil); } @@ -342,57 +344,54 @@ TEST_F(LiteralUtilTest, TokenEquality) { auto token1 = LiteralUtil::CreateToken(); auto scalar = LiteralUtil::CreateR0(1.0); - EXPECT_EQ(*token0, *token1); - EXPECT_NE(*token0, *scalar); + EXPECT_EQ(token0, token1); + EXPECT_NE(token0, scalar); - EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}), - *LiteralUtil::MakeTuple({token0.get()})); - EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({token1.get(), scalar.get()})); - EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({scalar.get(), token1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0}), + LiteralUtil::MakeTuple({&token0})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&token1, &scalar})); + EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&scalar, &token1})); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = - MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); - colmajor->Set({0, 0}, 1.0); - colmajor->Set({0, 1}, 2.0); - colmajor->Set({1, 0}, 3.0); - colmajor->Set({1, 1}, 4.0); + Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); + colmajor.Set({0, 0}, 1.0); + colmajor.Set({0, 1}, 2.0); + colmajor.Set({1, 0}, 3.0); + colmajor.Set({1, 1}, 4.0); - auto rowmajor = - MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); - rowmajor->Set({0, 0}, 1.0); - rowmajor->Set({0, 1}, 2.0); - rowmajor->Set({1, 0}, 3.0); - rowmajor->Set({1, 1}, 4.0); + Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); + rowmajor.Set({0, 0}, 1.0); + rowmajor.Set({0, 1}, 2.0); + rowmajor.Set({1, 0}, 3.0); + rowmajor.Set({1, 1}, 4.0); - EXPECT_EQ(*rowmajor, *colmajor); + EXPECT_EQ(rowmajor, colmajor); } TEST_F(LiteralUtilTest, TupleEquality) { // Test equality with tuples. auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. auto scalar_clone = LiteralUtil::CreateR0(1.0); - auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_EQ(*tuple1, *tuple2); + auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix}); + EXPECT_EQ(tuple1, tuple2); // Tuple with elements reversed. - auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_NE(*tuple1, *reversed_tuple); + auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar}); + EXPECT_NE(tuple1, reversed_tuple); // Tuple with different value. auto scalar_42 = LiteralUtil::CreateR0(42.0); - auto different_tuple = - LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_NE(*tuple1, *different_tuple); + auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix}); + EXPECT_NE(tuple1, different_tuple); } TEST_F(LiteralUtilTest, C64Equality) { @@ -403,162 +402,161 @@ TEST_F(LiteralUtilTest, C64Equality) { // tuple, the other is a clone of the element in the original tuple. auto vector_clone = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); - EXPECT_EQ(*vector, *vector_clone); + EXPECT_EQ(vector, vector_clone); auto vector_reversed = LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); - EXPECT_NE(*vector, *vector_reversed); + EXPECT_NE(vector, vector_reversed); } TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = LiteralUtil::CreateR0(0.0); auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); + auto tuple = LiteralUtil::MakeTuple({&element1, &element1}); // Tuples should always return false for IsAll. - EXPECT_FALSE(tuple->IsAll(0)); - EXPECT_FALSE(tuple->IsAll(1)); + EXPECT_FALSE(tuple.IsAll(0)); + EXPECT_FALSE(tuple.IsAll(1)); } // Verifies that CreateFromShape works for tuples. TEST_F(LiteralUtilTest, CreateFromShapeTuple) { auto scalar = LiteralUtil::CreateR0(0.0); auto matrix = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - auto x = Literal::CreateFromShape(tuple->shape()); - EXPECT_EQ(*tuple, *x); + auto x = Literal::CreateFromShape(tuple.shape()); + EXPECT_EQ(tuple, x); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(LiteralUtil::CreateR0(false)->IsAll(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(true)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(-1)); + EXPECT_TRUE(LiteralUtil::CreateR0(false).IsAll(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(true).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE(LiteralUtil::CreateR0(255)->IsAll(int8_min)); + EXPECT_FALSE(LiteralUtil::CreateR0(255).IsAll(int8_min)); - EXPECT_TRUE(LiteralUtil::CreateR0(42.0)->IsAll(42)); - EXPECT_FALSE(LiteralUtil::CreateR0(42.0001)->IsAll(42)); + EXPECT_TRUE(LiteralUtil::CreateR0(42.0).IsAll(42)); + EXPECT_FALSE(LiteralUtil::CreateR0(42.0001).IsAll(42)); - EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100})->IsAll(100)); - EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001})->IsAll(100)); + EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100}).IsAll(100)); + EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001}).IsAll(100)); - EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}}).IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}}).IsAll(8)); bfloat16 b8(8.0f); bfloat16 b9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}}).IsAll(8)); // 9.001 will be truncated to 9.0 bfloat16 b91(9.001f); bfloat16 b90(9.00f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}})->IsAll(9.0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); complex64 c8_9 = {8, 9}; - EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(LiteralUtil::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) - ->IsAll(-1)); + .IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not floating-point. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); EXPECT_TRUE(LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}) - ->IsAllFloat(.5)); + .IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsAllComplex) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c7_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); } TEST_F(LiteralUtilTest, IsAllFirst) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(LiteralUtil::CreateR1({false, true})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({false, false})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({false, true}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({false, false}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); - EXPECT_FALSE( - LiteralUtil::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}).IsAllFirst()); } TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = LiteralUtil::CreateR0(0.0f); auto scalar_one = LiteralUtil::CreateR0(1.0f); - EXPECT_TRUE(scalar_zero->IsZero({})); - EXPECT_FALSE(scalar_one->IsZero({})); + EXPECT_TRUE(scalar_zero.IsZero({})); + EXPECT_FALSE(scalar_one.IsZero({})); auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); - EXPECT_FALSE(array->IsZero({0, 1})); - EXPECT_TRUE(array->IsZero({0, 2})); - EXPECT_TRUE(array->IsZero({1, 1})); - EXPECT_FALSE(array->IsZero({1, 2})); + EXPECT_FALSE(array.IsZero({0, 1})); + EXPECT_TRUE(array.IsZero({0, 2})); + EXPECT_TRUE(array.IsZero({1, 1})); + EXPECT_FALSE(array.IsZero({1, 2})); auto complex_zero = LiteralUtil::CreateR0(0.0f); auto complex_nonzero = LiteralUtil::CreateR0(0.5f); - EXPECT_TRUE(complex_zero->IsZero({})); - EXPECT_FALSE(complex_nonzero->IsZero({})); + EXPECT_TRUE(complex_zero.IsZero({})); + EXPECT_FALSE(complex_nonzero.IsZero({})); } template @@ -574,19 +572,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); - auto data01 = data->Relayout(layout01); - EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_EQ(*data, *data01); + auto data01 = data.Relayout(layout01); + EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01)); + EXPECT_EQ(data, data01); - auto data10 = data->Relayout(layout10); - EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_EQ(*data, *data10); + auto data10 = data.Relayout(layout10); + EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10)); + EXPECT_EQ(data, data10); } TEST_F(LiteralUtilTest, ReshapeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, ReshapeR4) { @@ -604,9 +602,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { @@ -624,15 +622,15 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, TransposeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Transpose(/*permutation=*/{}); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Transpose(/*permutation=*/{}); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, TransposeR4) { @@ -644,10 +642,10 @@ TEST_F(LiteralUtilTest, TransposeR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}); // clang-format on - auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); + auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell([&](ArraySlice indices, float value) { - EXPECT_EQ(value, original->Get( + reshape.EachCell([&](absl::Span indices, float value) { + EXPECT_EQ(value, original.Get( {indices[2], indices[3], indices[0], indices[1]})); }); } @@ -656,35 +654,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // Tests that using Relayout on an array is equivalent to creating it in the // target layout in the first place. auto dim0minor_relaid_to_dim0major = - literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major); + literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major); auto dim0major_relaid_to_dim0minor = - literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor); + literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); - EXPECT_EQ(mat_dim0minor->element_count(), 6); - EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); + EXPECT_EQ(mat_dim0minor.element_count(), 6); + EXPECT_THAT(mat_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. - auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); - EXPECT_THAT(relaid_mat_to_dim0major->data(), + auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_); + EXPECT_THAT(relaid_mat_to_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); - EXPECT_EQ(mat_dim0major->element_count(), 6); - EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); + EXPECT_EQ(mat_dim0major.element_count(), 6); + EXPECT_THAT(mat_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. - auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); - EXPECT_THAT(relaid_mat_to_dim0minor->data(), + auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_); + EXPECT_THAT(relaid_mat_to_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); } @@ -705,77 +703,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0minor_); - EXPECT_EQ(lit_dim0minor->element_count(), 12); + EXPECT_EQ(lit_dim0minor.element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; - EXPECT_THAT(lit_dim0minor->data(), + EXPECT_THAT(lit_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. - auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); + auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - EXPECT_THAT(relaid_lit_to_dim0major->data(), + EXPECT_THAT(relaid_lit_to_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0major_); - EXPECT_EQ(lit_dim0major->element_count(), 12); - EXPECT_THAT(lit_dim0major->data(), + EXPECT_EQ(lit_dim0major.element_count(), 12); + EXPECT_THAT(lit_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. - auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); - EXPECT_THAT(relaid_lit_to_dim0minor->data(), + auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_); + EXPECT_THAT(relaid_lit_to_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { auto input = LiteralUtil::CreateR0(1); - auto result = input->Slice({}, {}); - EXPECT_EQ(*input, *result); + auto result = input.Slice({}, {}); + EXPECT_EQ(input, result); } TEST_F(LiteralUtilTest, SliceR1F32) { auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); - auto result = input->Slice({3}, {4}); + auto result = input.Slice({3}, {4}); auto expected = LiteralUtil::CreateR1({4.0}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR2U32) { auto input_3x4 = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto result = input_3x4->Slice({0, 2}, {2, 4}); + auto result = input_3x4.Slice({0, 2}, {2, 4}); auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { auto input_2x3x2 = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); - EXPECT_EQ(*input_2x3x2, *result); + auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2}); + EXPECT_EQ(input_2x3x2, result); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); auto expected = LiteralUtil::CreateR1({77}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { Literal output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR2C64) { @@ -783,7 +781,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = LiteralUtil::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { @@ -791,7 +789,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { bfloat16 h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { @@ -799,7 +797,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { bfloat16 h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { @@ -807,28 +805,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { bfloat16 h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output(ShapeUtil::MakeShape(F32, {})); output.PopulateWithValue(2.5f); auto expected = LiteralUtil::CreateR0(2.5f); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output(ShapeUtil::MakeShape(S64, {3})); output.PopulateWithValue(-7); auto expected = LiteralUtil::CreateR1({-7, -7, -7}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output(ShapeUtil::MakeShape(U64, {2, 2})); output.PopulateWithValue(42); auto expected = LiteralUtil::CreateR2({{42, 42}, {42, 42}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { @@ -836,7 +834,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { output.PopulateWithValue({4, 2}); auto expected = LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { @@ -844,7 +842,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { half h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { @@ -852,7 +850,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { half h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { @@ -860,18 +858,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { half h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto output = input->Replicate(3); + auto output = input.Replicate(3); auto expected = LiteralUtil::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_EQ(*output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, CopySliceFrom) { @@ -886,35 +884,35 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; - auto init_proc = [&](ArraySlice indexes) { - source->Set(indexes, ++seqnr); + auto init_proc = [&](absl::Span indexes) { + source.Set(indexes, ++seqnr); return true; }; - ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step, init_proc); auto blank = Literal::CreateFromShape(shape); const int64 src_base[] = {3, 1, 5, 7}; const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size)); + TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; - auto check_proc = [&](ArraySlice indexes) { + auto check_proc = [&](absl::Span indexes) { std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); std::transform(source_indexes.begin(), source_indexes.end(), src_base, source_indexes.begin(), std::plus()); std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, blank_indexes.begin(), std::plus()); - auto bval = blank->Get(blank_indexes); - matched = (bval != 0 && bval == source->Get(source_indexes)); + auto bval = blank.Get(blank_indexes); + matched = (bval != 0 && bval == source.Get(source_indexes)); return matched; }; - ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step, check_proc); EXPECT_TRUE(matched); } @@ -923,14 +921,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { TEST_F(LiteralUtilTest, CopyFromScalars) { auto zero = LiteralUtil::CreateR0(0); auto nine = LiteralUtil::CreateR0(9); - TF_EXPECT_OK(zero->CopyFrom(*nine)); - EXPECT_EQ(*zero, *nine); + TF_EXPECT_OK(zero.CopyFrom(nine)); + EXPECT_EQ(zero, nine); auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); - TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); - EXPECT_EQ(zero->Get({}), 17); - TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); - EXPECT_EQ(vect->Get({4}), 17); + TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {})); + EXPECT_EQ(zero.Get({}), 17); + TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {})); + EXPECT_EQ(vect.Get({4}), 17); } TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { @@ -943,17 +941,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); - EXPECT_EQ(*nine, *const_nine); + TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0})); + EXPECT_EQ(nine, const_nine); } { // Copy 0 element to destination with zero elements. - const auto empty = Literal::CreateFromShape(empty_r1_shape); + auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); - EXPECT_EQ(*empty, *const_empty); + TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0})); + EXPECT_EQ(empty, const_empty); } } @@ -967,76 +965,77 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) { TEST_F(LiteralUtilTest, CopyFromArrays) { auto scalar_42 = LiteralUtil::CreateR0(42.0); auto scalar_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*scalar_42, *scalar_123); - TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*scalar_42, *scalar_123); - EXPECT_EQ(scalar_42->Get({}), 123.0f); + EXPECT_NE(scalar_42, scalar_123); + TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(scalar_42, scalar_123); + EXPECT_EQ(scalar_42.Get({}), 123.0f); auto matrix_1234 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto matrix_5678 = LiteralUtil::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); - EXPECT_NE(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); - TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 5.0f); + EXPECT_NE(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 1.0f); + TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 5.0f); } TEST_F(LiteralUtilTest, CopyFromTuples) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {matrix.get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get()}); + Literal inner_elements[] = {LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0})}; + Literal inner_tuple = LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}); + Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple}); // Create a tuple the same shape as the inner tuple of nested_tuple but with // different values.. - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(-5).get(), - LiteralUtil::CreateR1({2.0, 4.0}).get(), &nil_literal}); + Literal int32_minus5 = LiteralUtil::CreateR0(-5); + Literal double_2_4 = LiteralUtil::CreateR1({2.0, 4.0}); + Literal tuple = + LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal}); - EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), 42); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 23.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 44.0); // Overwrite the inner tuple element of nested_tuple with the contents of // 'tuple'. - TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{})); + TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{})); // The matrix element should be unchanged. - EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 2.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), -5); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 2.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 4.0); } TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { - auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0(-2).get(), - LiteralUtil::CreateR0(4).get()}); + Literal elements[] = {LiteralUtil::CreateR0(-2), + LiteralUtil::CreateR0(4)}; + Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), 4); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), 4); // Copy from one element to the other. - TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{0})); + TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{0})); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), -2); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), -2); } TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto vector = LiteralUtil::CreateR1({5.0, 7.0}); - Status status = matrix->CopyFrom(*vector); + Status status = matrix.CopyFrom(vector); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Destination subshape incompatible")); } @@ -1044,9 +1043,8 @@ TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent - auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); - Literal* l1 = m1.get(); - const char* d1 = reinterpret_cast(l1->data().data()); + Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + const char* d1 = reinterpret_cast(m1.data().data()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -1059,8 +1057,7 @@ TEST_F(LiteralUtilTest, F16) { half h1(1.0f); half h2(2.0f); auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l2 = m2.get(); - const char* d2 = reinterpret_cast(l2->data().data()); + const char* d2 = reinterpret_cast(m2.data().data()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -1089,25 +1086,25 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = MakeUnique(shape); - auto generator = [&](ArraySlice indexes) -> uint32 { + Literal literal(shape); + auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->Populate(generator)); + TF_EXPECT_OK(literal.Populate(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](ArraySlice indexes) { - auto value = literal->Get(indexes); + auto check_function = [&](absl::Span indexes) { + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1131,25 +1128,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = MakeUnique(shape); - auto generator = [&](ArraySlice indexes) -> uint32 { + Literal literal(shape); + auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->PopulateParallel(generator)); + TF_EXPECT_OK(literal.PopulateParallel(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](ArraySlice indexes) { - auto value = literal->Get(indexes); + auto check_function = [&](absl::Span indexes) { + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1168,10 +1165,9 @@ TEST_F(LiteralUtilTest, ConvertR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // clang-format on - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->Convert(U32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32)); - EXPECT_EQ(*expected, *converted); + EXPECT_EQ(expected, converted); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { @@ -1243,69 +1239,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); // clang-format on - std::unique_ptr conv; + Literal conv; - conv = s8->Convert(U32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u32); + conv = s8.Convert(U32).ConsumeValueOrDie(); + EXPECT_EQ(conv, u32); - conv = s8->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = s8.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s8->Convert(U64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u64); + conv = s8.Convert(U64).ConsumeValueOrDie(); + EXPECT_EQ(conv, u64); - conv = s8->Convert(S64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s64); + conv = s8.Convert(S64).ConsumeValueOrDie(); + EXPECT_EQ(conv, s64); - conv = s8->Convert(PRED).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *pred); + conv = s8.Convert(PRED).ConsumeValueOrDie(); + EXPECT_EQ(conv, pred); - conv = bf16->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = bf16.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = bf16->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = bf16.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = pred->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *int32_pred); + conv = pred.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, int32_pred); - conv = f32->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f32.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = f64->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f64.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s32->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = s32.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = f32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = f64->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f64.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = s32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = u32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = u32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = s32.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - conv = f16->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = f16.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - EXPECT_EQ(s32->Convert(TUPLE).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(S16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(U16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(F32).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(S32).status().code(), + EXPECT_EQ(s32.Convert(TUPLE).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -1315,16 +1307,15 @@ TEST_F(LiteralUtilTest, BitcastConvert) { tensorflow::bit_cast(100.f), 0xbeef}); auto expected = LiteralUtil::CreateR1( {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->BitcastConvert(F32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32)); } TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { auto literal = LiteralUtil::CreateR0(1234); - Status status = literal->BitcastConvert(F64).status(); + Status status = literal.BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(), - "bit widths are different")); + EXPECT_TRUE( + absl::StrContains(status.error_message(), "bit widths are different")); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { @@ -1339,11 +1330,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) { p.add_preds((i % 2) == (len % 2)); } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - ASSERT_EQ(len, literal->data().size()); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + ASSERT_EQ(len, literal.data().size()); int i = 0; - for (bool value : literal->data()) { + for (bool value : literal.data()) { EXPECT_EQ((i % 2) == (len % 2), value); ++i; } @@ -1356,11 +1346,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) { half h2(2.0f); auto m = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l = m.get(); - EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); - EXPECT_EQ(4, l->data().size()); + EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape())); + EXPECT_EQ(4, m.data().size()); - LiteralProto p = l->ToProto(); + LiteralProto p = m.ToProto(); EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); EXPECT_EQ(8, p.f16s().size()); const char* d = p.f16s().data(); @@ -1387,56 +1376,53 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { LayoutUtil::SetToDefaultLayout(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - auto r = literal->data(); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + auto r = literal.data(); ASSERT_EQ(4, r.size()); - ASSERT_EQ(h1, r[0]); - ASSERT_EQ(h2, r[1]); - ASSERT_EQ(h2, r[2]); - ASSERT_EQ(h1, r[3]); + EXPECT_EQ(h1, r[0]); + EXPECT_EQ(h2, r[1]); + EXPECT_EQ(h2, r[2]); + EXPECT_EQ(h1, r[3]); } TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); - EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); - EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(scalar, {}), scalar); + EXPECT_EQ(LiteralSlice(matrix, {}), matrix); + EXPECT_EQ(LiteralSlice(tuple, {}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple); EXPECT_EQ(LiteralSlice(nil, {}), nil); - EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(tuple, {0}), scalar); + EXPECT_EQ(LiteralSlice(tuple, {1}), matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix); + EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar); } TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. - const auto nested_tuple_view = LiteralSlice(*nested_tuple); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 1.0f); + const auto nested_tuple_view = LiteralSlice(nested_tuple); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 1.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); - nested_tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 555.0f); + nested_tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 555.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 555.0f); @@ -1445,14 +1431,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) { TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); - const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(nested_tuple); const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { @@ -1495,9 +1481,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { } TEST_F(LiteralUtilTest, LiteralMove) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - Literal literal(std::move(*matrix)); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal(std::move(matrix)); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1509,17 +1494,21 @@ TEST_F(LiteralUtilTest, LiteralMove) { TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get(), - &nil_literal}); - - EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape())); - std::vector elements = nested_tuple->DecomposeTuple(); - EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape())); + Literal inner_elements[] = { + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0}), + }; + Literal tuple_elements[] = { + LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}), + }; + Literal nested_tuple = LiteralUtil::MakeTuple( + {&tuple_elements[0], &tuple_elements[1], &nil_literal}); + + EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape())); + std::vector elements = nested_tuple.DecomposeTuple(); + EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape())); ASSERT_EQ(elements.size(), 3); @@ -1550,15 +1539,15 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { TEST_F(LiteralUtilTest, MoveIntoTuple) { std::vector elements; - elements.push_back(std::move(*LiteralUtil::CreateR0(1.0))); - elements.push_back(std::move(*LiteralUtil::CreateR1({4, 8}))); - elements.push_back(std::move(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get()}) - - )); - - Literal literal = Literal::MoveIntoTuple(&elements); + elements.push_back(LiteralUtil::CreateR0(1.0)); + elements.push_back(LiteralUtil::CreateR1({4, 8})); + std::vector inner_elements; + inner_elements.push_back(LiteralUtil::CreateR0(42)); + inner_elements.push_back(LiteralUtil::CreateR1({23.0, 44.0})); + elements.push_back( + LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]})); + + Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements)); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3); @@ -1577,16 +1566,15 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { Literal literal = Literal::MoveIntoTuple({}); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); - ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); + EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); } TEST_F(LiteralUtilTest, LiteralMoveAssignment) { Literal literal; EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - literal = std::move(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + literal = std::move(matrix); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1597,9 +1585,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { } TEST_F(LiteralUtilTest, LiteralSliceCopy) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralSlice(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + const auto matrix_view = LiteralSlice(matrix); LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); @@ -1609,45 +1596,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) { } TEST_F(LiteralUtilTest, GetSetTuple) { - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42.0).get(), - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); - tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); - - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), - 3.0); - tuple->Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), + Literal elements[] = { + LiteralUtil::CreateR0(42.0), + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + }; + auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); + tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); + + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0); + tuple.Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), -4.0); } TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { // Literals constructed using CreateFromShape should be zero initialized. - std::unique_ptr scalar_f32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); - EXPECT_EQ(scalar_f32->Get({}), 0.0); - EXPECT_TRUE(scalar_f32->IsAll(0)); - - std::unique_ptr vector_s32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); - EXPECT_EQ(vector_s32->Get({0}), 0); - EXPECT_EQ(vector_s32->Get({1}), 0); - EXPECT_EQ(vector_s32->Get({2}), 0); - EXPECT_TRUE(vector_s32->IsAll(0)); - - std::unique_ptr tuple = - Literal::CreateFromShape(ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), - ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); - - EXPECT_EQ(tuple->Get({}, {0}), 0.0); - EXPECT_EQ(tuple->Get({0}, {1}), false); - EXPECT_EQ(tuple->Get({1}, {1}), false); - EXPECT_EQ(tuple->Get({0, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({1, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({}, {3}), complex64(0.0f, 0.0f)); + Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); + EXPECT_EQ(scalar_f32.Get({}), 0.0); + EXPECT_TRUE(scalar_f32.IsAll(0)); + + Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); + EXPECT_EQ(vector_s32.Get({0}), 0); + EXPECT_EQ(vector_s32.Get({1}), 0); + EXPECT_EQ(vector_s32.Get({2}), 0); + EXPECT_TRUE(vector_s32.IsAll(0)); + + Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + + EXPECT_EQ(tuple.Get({}, {0}), 0.0); + EXPECT_EQ(tuple.Get({0}, {1}), false); + EXPECT_EQ(tuple.Get({1}, {1}), false); + EXPECT_EQ(tuple.Get({0, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({1, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({}, {3}), complex64(0.0f, 0.0f)); } TEST_F(LiteralUtilTest, ProtoRoundTrip) { @@ -1655,6 +1640,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto one_f32 = LiteralUtil::CreateR0(1.0); auto two_f32 = LiteralUtil::CreateR0(2.0); auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); + auto vector_uint8 = LiteralUtil::CreateR1({128, 0, 2, 56, 127, 255}); auto vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); @@ -1663,25 +1649,27 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto tuple = LiteralUtil::MakeTuple( - {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); + {&one_f32, &vector_half, &matrix_pred, &matrix_pred}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); + auto nested_tuple = + LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal}); auto to_from_proto = [](const Literal& literal) -> Literal { - return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie()); + return Literal::CreateFromProto(literal.ToProto()).ValueOrDie(); }; - EXPECT_EQ(*one_f32, to_from_proto(*one_f32)); - EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64)); - EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16)); - EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred)); - EXPECT_EQ(*tuple, to_from_proto(*tuple)); - EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple)); + EXPECT_EQ(one_f32, to_from_proto(one_f32)); + EXPECT_EQ(vector_int8, to_from_proto(vector_int8)); + EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8)); + EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); + EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); + EXPECT_EQ(tuple, to_from_proto(tuple)); + EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple)); EXPECT_EQ(nil_literal, to_from_proto(nil_literal)); - EXPECT_NE(*one_f32, *two_f32); - EXPECT_NE(*one_f32, to_from_proto(*two_f32)); + EXPECT_NE(one_f32, two_f32); + EXPECT_NE(one_f32, to_from_proto(two_f32)); } TEST_F(LiteralUtilTest, InvalidProtoNoValues) { @@ -1690,7 +1678,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) { *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 3 elements in LiteralProto")); } @@ -1702,7 +1690,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); + EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); } TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { @@ -1714,7 +1702,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 3 elements in LiteralProto")); } @@ -1727,7 +1715,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { proto.add_f32s(3.0); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 84 elements in LiteralProto")); } @@ -1740,7 +1728,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { proto.add_s32s(100); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 elements in LiteralProto")); } @@ -1755,7 +1743,7 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); + EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); } TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { @@ -1771,7 +1759,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { @@ -1794,17 +1782,17 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } TEST_F(LiteralUtilTest, SortSparseElements) { auto literal = LiteralUtil::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); - literal->AppendSparseElement({2, 3, 4}, 2.0); - literal->AppendSparseElement({3, 4, 5}, 3.0); - literal->AppendSparseElement({1, 2, 3}, 1.0); - literal->SortSparseElements(); - ASSERT_EQ(literal->ToString(false), + literal.AppendSparseElement({2, 3, 4}, 2.0); + literal.AppendSparseElement({3, 4, 5}, 3.0); + literal.AppendSparseElement({1, 2, 3}, 1.0); + literal.SortSparseElements(); + EXPECT_EQ(literal.ToString(false), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } @@ -1812,60 +1800,56 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { std::vector dimensions = {10, 10, 10}; SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); - ASSERT_EQ( + EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), "false"); - ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(int64{2})); - ASSERT_EQ( + EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) + .GetSparseElementAsString(1), + absl::StrCat(int64{2})); + EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(double{2.0})); - ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, + .GetSparseElementAsString(1), + absl::StrCat(double{2.0})); + EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(static_cast(half{2.0}))); - ASSERT_EQ( - LiteralUtil::CreateSparse( - dimensions, indices, - std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); + .GetSparseElementAsString(1), + absl::StrCat(static_cast(half{2.0}))); + EXPECT_EQ(LiteralUtil::CreateSparse( + dimensions, indices, + std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) + .GetSparseElementAsString(1), + absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{0})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 1}, {2, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 1}, {2, 2}})); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{1})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 2}, {1, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 2}, {1, 2}})); } TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { - std::unique_ptr literal = LiteralUtil::CreateR0(9); + Literal literal = LiteralUtil::CreateR0(9); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), - /*dimensions=*/{})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{9, 9}, {9, 9}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + /*dimensions=*/{})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } } // namespace diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 356f12ed789d82bc2716b5eafc411e4cafbba2ff..0cb1ae35f4ad31f091063d78ed32c1463be8ee0a 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -22,6 +22,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -30,22 +33,19 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; + // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template -std::unique_ptr ConvertType(LiteralSlice literal) { +Literal ConvertType(LiteralSlice literal) { // First construct shape of the result. Shape result_shape(literal.shape()); ShapeUtil::ForEachMutableSubshape( @@ -56,7 +56,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { primitive_util::NativeToPrimitiveType()); } }); - auto result = MakeUnique(result_shape); + Literal result(result_shape); // Then copy over the data from 'literal' converting FromNativeT values to // ToNativeT values as necessary. @@ -67,14 +67,14 @@ std::unique_ptr ConvertType(LiteralSlice literal) { if (subshape.element_type() == primitive_util::NativeToPrimitiveType()) { auto src = literal.data(shape_index); - auto dest = result->data(shape_index); + auto dest = result.data(shape_index); for (int64 i = 0; i < src.size(); ++i) { dest[i] = static_cast(src[i]); } } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); + TF_CHECK_OK(result.CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); } } }); @@ -83,54 +83,52 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace -/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions) { +/* static */ Literal LiteralUtil::CreateFromDimensions( + PrimitiveType primitive_type, absl::Span dimensions) { return Literal::CreateFromShape( ShapeUtil::MakeShape(primitive_type, dimensions)); } -/* static */ std::unique_ptr LiteralUtil::ConvertBF16ToF32( +/* static */ Literal LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); } -/* static */ std::unique_ptr LiteralUtil::ConvertF32ToBF16( +/* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); } -/* static */ std::unique_ptr LiteralUtil::CreateToken() { - return MakeUnique(ShapeUtil::MakeTokenShape()); +/* static */ Literal LiteralUtil::CreateToken() { + return Literal(ShapeUtil::MakeTokenShape()); } /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case C64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -146,30 +144,29 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case C64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -185,42 +182,36 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case F32: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -233,40 +224,34 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case F32: - return std::move(*LiteralUtil::CreateR0( - std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -276,34 +261,31 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } } -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ Literal LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = MakeUnique( + Literal literal( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR1U8( - tensorflow::StringPiece value) { - auto literal = MakeUnique( - ShapeUtil::MakeShape(U8, {static_cast(value.size())})); +/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) { + Literal literal(ShapeUtil::MakeShape(U8, {static_cast(value.size())})); for (int i = 0; i < value.size(); ++i) { - literal->Set({i}, value[i]); + literal.Set({i}, value[i]); } return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( - float from, float to, int64 rows, int64 cols) { +/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to, + int64 rows, int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -/* static */ std::unique_ptr LiteralUtil::ReshapeSlice( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const LiteralSlice& literal) { +/* static */ Literal LiteralUtil::ReshapeSlice( + absl::Span new_dimensions, + absl::Span minor_to_major, const LiteralSlice& literal) { int64 new_num_elements = 1; for (int64 i = 0; i < new_dimensions.size(); ++i) { new_num_elements *= new_dimensions[i]; @@ -311,13 +293,13 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - auto new_literal = MakeUnique( + Literal new_literal( ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. This shape is used // solely for converting linear address to multi-dimensional addresses when // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); + Shape shape_with_layout = new_literal.shape(); *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); // Copy data into new literal, element-by-element. @@ -328,40 +310,40 @@ std::unique_ptr ConvertType(LiteralSlice literal) { IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); switch (literal.shape().element_type()) { case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; default: LOG(FATAL) << "Unhandled primitive element type: " @@ -378,101 +360,89 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); switch (literal.shape().element_type()) { case PRED: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 8 bit types. case S8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 16 bit types. case BF16: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 32 bit types. case F32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 64 bit types. case C64: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); default: LOG(FATAL) << "Unhandled primitive type " << literal.shape().element_type(); } } -/* static */ std::unique_ptr LiteralUtil::MakeTuple( - tensorflow::gtl::ArraySlice elements) { +/* static */ Literal LiteralUtil::MakeTuple( + absl::Span elements) { std::vector element_shapes; for (const auto* element : elements) { element_shapes.push_back(element->shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( - tensorflow::gtl::ArraySlice elements) { +/* static */ Literal LiteralUtil::MakeTupleFromSlices( + absl::Span elements) { std::vector element_shapes; for (const auto& element : elements) { element_shapes.push_back(element.shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr LiteralUtil::MakeTupleOwned( - std::vector> elements) { +/* static */ Literal LiteralUtil::MakeTupleOwned( + std::vector elements) { std::vector element_shapes; element_shapes.reserve(elements.size()); for (const auto& element : elements) { - element_shapes.push_back(element->shape()); + element_shapes.push_back(element.shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int64 i = 0; i < elements.size(); ++i) { TF_CHECK_OK( - literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); } return literal; } /* static */ string LiteralUtil::MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index) { - return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); + absl::Span multi_index) { + return StrCat("{", absl::StrJoin(multi_index, ","), "}"); } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e3737a9d0051b32dc0becc19e1849c856a50e52e..2b181621ed92be8952ccec19e0d4229c494b9f47 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -27,6 +27,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -34,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -43,8 +45,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -69,37 +69,34 @@ class LiteralUtil { // The variants not ending with WithLayout use the default XLA layout for the // literal's linear representation in memory. template - static std::unique_ptr CreateR0(NativeT value); + static Literal CreateR0(NativeT value); template - static std::unique_ptr CreateR1( - tensorflow::gtl::ArraySlice values); - static std::unique_ptr CreateR1( - const tensorflow::core::Bitmap& values); + static Literal CreateR1(absl::Span values); + static Literal CreateR1(const tensorflow::core::Bitmap& values); template - static std::unique_ptr CreateR2( + static Literal CreateR2( std::initializer_list> values); template - static std::unique_ptr CreateR2WithLayout( + static Literal CreateR2WithLayout( std::initializer_list> values, const Layout& layout); template - static std::unique_ptr CreateR3( - std::initializer_list< - std::initializer_list>> - values); + static Literal CreateR3(std::initializer_list< + std::initializer_list>> + values); template - static std::unique_ptr CreateR3WithLayout( + static Literal CreateR3WithLayout( std::initializer_list< std::initializer_list>> values, const Layout& layout); template - static std::unique_ptr CreateR4( + static Literal CreateR4( std::initializer_list>>> values); template - static std::unique_ptr CreateR4WithLayout( + static Literal CreateR4WithLayout( std::initializer_list>>> values, @@ -140,9 +137,10 @@ class LiteralUtil { // [9, 10, 11]: 4.0 // template - static std::unique_ptr CreateSparse( - tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, bool sort = true); + static Literal CreateSparse(absl::Span dimensions, + SparseIndexArray indices, + absl::Span values, + bool sort = true); // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -156,132 +154,120 @@ class LiteralUtil { static Literal MaxValue(PrimitiveType primitive_type); // Creates a literal of the given shape where each element is `value`. template - static std::unique_ptr CreateFullWithDescendingLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value); + static Literal CreateFullWithDescendingLayout( + absl::Span dimensions, NativeT value); // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear // representation in memory. template - static std::unique_ptr CreateFromArray(const Array& values); + static Literal CreateFromArray(const Array& values); template - static std::unique_ptr CreateFromArrayWithLayout( - const Array& values, const Layout& layout); + static Literal CreateFromArrayWithLayout(const Array& values, + const Layout& layout); template - static std::unique_ptr CreateR2FromArray2D( - const Array2D& values); + static Literal CreateR2FromArray2D(const Array2D& values); template - static std::unique_ptr CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout); + static Literal CreateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); template - static std::unique_ptr CreateR3FromArray3D( - const Array3D& values); + static Literal CreateR3FromArray3D(const Array3D& values); template - static std::unique_ptr CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout); + static Literal CreateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); template - static std::unique_ptr CreateR4FromArray4D( - const Array4D& values); + static Literal CreateR4FromArray4D(const Array4D& values); template - static std::unique_ptr CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout); + static Literal CreateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); // Creates a new vector of U8s literal value from a string. - static std::unique_ptr CreateR1U8(tensorflow::StringPiece value); + static Literal CreateR1U8(absl::string_view value); // Creates a linspace-populated literal with the given number of rows and // columns. - static std::unique_ptr CreateR2F32Linspace(float from, float to, - int64 rows, int64 cols); + static Literal CreateR2F32Linspace(float from, float to, int64 rows, + int64 cols); // Creates a literal that projects the (x, y) dimensions given in values into // the z dimension given by "projection". template - static std::unique_ptr CreateR3Projected( + static Literal CreateR3Projected( std::initializer_list> values, int64 projection); // Creates a literal that projects the (x, y) dimensions given in values into // the z and p dimensions given. template - static std::unique_ptr CreateR4Projected( + static Literal CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template - static std::unique_ptr MakeIdentityR2(int64 size); + static Literal MakeIdentityR2(int64 size); // Returns a tuple literal composed of given literals. Data is copied from the // given elements into the returned literal. - static std::unique_ptr MakeTuple( - tensorflow::gtl::ArraySlice elements); + static Literal MakeTuple(absl::Span elements); - static std::unique_ptr MakeTupleFromSlices( - tensorflow::gtl::ArraySlice elements); + static Literal MakeTupleFromSlices(absl::Span elements); // As above, but intended to be invoked with move semantics; i.e. // - // std::vector> elements = ...; + // std::vector elements = ...; // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); // // This would have been declared as an overload, but there is ambiguity // in invocation between the above signature and this one. - static std::unique_ptr MakeTupleOwned( - std::vector> elements); + static Literal MakeTupleOwned(std::vector elements); - // This overload lets you pass a braced list of unique_ptrs to + // This overload lets you pass a braced list of Literals to // MakeTupleOwned: // // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). // - // Simply relying on the MakeTupleOwned(std::vector>) + // Simply relying on the MakeTupleOwned(std::vector) // overload doesn't work because std::initializer_list's elements are always // const. // - // The arguments to this function must all be unique_ptr. + // The arguments to this function must all be Literal. template - static std::unique_ptr MakeTupleOwned( - std::unique_ptr... elements) { - std::array, sizeof...(Ts)> arr{ - std::move(elements)...}; - std::vector> v; + static Literal MakeTupleOwned(Ts... elements) { + std::array arr{std::move(elements)...}; + std::vector v; v.insert(v.begin(), std::make_move_iterator(arr.begin()), std::make_move_iterator(arr.end())); return MakeTupleOwned(std::move(v)); } // Create a constant token literal. Token types have no value. - static std::unique_ptr CreateToken(); + static Literal CreateToken(); // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); + static Literal CreateFromDimensions(PrimitiveType primitive_type, + absl::Span dimensions); // If the given literal's data type is bfloat16, converts it to a float // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertBF16ToF32( - const LiteralSlice& bf16_literal); + static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); // If the given literal's data type is float, converts it to a bfloat16 // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16( - const LiteralSlice& f32_literal); + static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); // Creates a literal with a new shape with the given new dimensions using the // data in the given input literal. For reshaping purposes the (flat) data // buffer of the input literal is assumed to have the given minor_to_major // layout order. - static std::unique_ptr ReshapeSlice( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const LiteralSlice& literal); + static Literal ReshapeSlice(absl::Span new_dimensions, + absl::Span minor_to_major, + const LiteralSlice& literal); // Creates a literal with the supplied shape, and uses the provided value // generator to populate the literal's values. @@ -289,9 +275,9 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( + static StatusOr CreateRandomLiteral( const Shape& shape, - const std::function)>& generator); + const std::function)>& generator); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -300,8 +286,8 @@ class LiteralUtil { template < PrimitiveType type, typename E, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, E* engine, + T mean, T stddev); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -310,8 +296,8 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, T mean, + T stddev); // // End of factory methods. @@ -319,51 +305,49 @@ class LiteralUtil { // Returns a multi-dimensional index as a string. For example: '{7, 8}' will // be returned for a 2-dimensional index with dimension 0 index equal to 7, // dimension 1 equal to 8. - static string MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index); + static string MultiIndexAsString(absl::Span multi_index); }; std::ostream& operator<<(std::ostream& out, const Literal& literal); template -/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { - auto literal = MakeUnique(ShapeUtil::MakeShape( +/* static */ Literal LiteralUtil::CreateR0(NativeT value) { + Literal literal(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); - literal->Set({}, value); + literal.Set({}, value); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR1( - tensorflow::gtl::ArraySlice values) { - auto literal = MakeUnique( +/* static */ Literal LiteralUtil::CreateR1(absl::Span values) { + Literal literal( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( +/* static */ Literal LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, AsInt64Slice(layout.minor_to_major()))); - literal->PopulateR2(values); + literal.PopulateR2(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2( +/* static */ Literal LiteralUtil::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( +/* static */ Literal LiteralUtil::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -388,14 +372,14 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR3( +/* static */ Literal LiteralUtil::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( +/* static */ Literal LiteralUtil::CreateR4WithLayout( std::initializer_list>>> values, @@ -426,22 +410,22 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateSparse( - tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, bool sort) { +/* static */ Literal LiteralUtil::CreateSparse( + absl::Span dimensions, SparseIndexArray indices, + absl::Span values, bool sort) { int64 num_elements = values.size(); int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); CHECK_EQ(rank, indices.rank()); - auto literal = MakeUnique(ShapeUtil::MakeShapeWithSparseLayout( + Literal literal(ShapeUtil::MakeShapeWithSparseLayout( primitive_util::NativeToPrimitiveType(), dimensions, indices.max_indices())); - literal->PopulateSparse(indices, values, sort); + literal.PopulateSparse(indices, values, sort); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR4( +/* static */ Literal LiteralUtil::CreateR4( std::initializer_list>>> values) { @@ -449,50 +433,48 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( +/* static */ Literal LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), AsInt64Slice(layout.minor_to_major()))); - literal->PopulateFromArray(values); + literal.PopulateFromArray(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArray( +/* static */ Literal LiteralUtil::CreateFromArray( const Array& values) { return CreateFromArrayWithLayout( values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( +/* static */ Literal LiteralUtil::CreateR2FromArray2D( const Array2D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( +/* static */ Literal LiteralUtil::CreateR3FromArray3D( const Array3D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( +/* static */ Literal LiteralUtil::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -517,7 +499,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( +/* static */ Literal LiteralUtil::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -545,21 +527,20 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( +/* static */ Literal LiteralUtil::CreateR4FromArray4D( const Array4D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } // Returns an identity matrix (rank 2) with the given row and column count. template -/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { +/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -568,45 +549,39 @@ template } template -/* static */ std::unique_ptr -LiteralUtil::CreateFullWithDescendingLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithDescendingLayout( +/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout( + absl::Span dimensions, NativeT value) { + Literal literal(ShapeUtil::MakeShapeWithDescendingLayout( primitive_util::NativeToPrimitiveType(), dimensions)); - literal->PopulateWithValue(value); + literal.PopulateWithValue(value); return literal; } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral( +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( const Shape& shape, - const std::function)>& generator) { + const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - auto literal = MakeUnique(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( - [&](tensorflow::gtl::ArraySlice indexes) { - return generator(indexes); - })); + Literal literal(shape); + TF_RETURN_IF_ERROR(literal.Populate( + [&](absl::Span indexes) { return generator(indexes); })); return std::move(literal); } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( - shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { - return generator(*engine); - }); + shape, + [&](absl::Span /*indexes*/) { return generator(*engine); }); } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral(shape, &engine, mean, stddev); } diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 3c74e070da529b7f1431e01fbaf31932f582db44..fcff48b6b18ba115a67f3141a9aea4ca461be55d 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -60,7 +60,7 @@ MaybeFind(const Collection& collection, if (it == collection.end()) { std::ostringstream os; os << key; - return NotFound("key not found: %s", os.str().c_str()); + return NotFound("key not found: %s", os.str()); } return {it->second}; } diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index 69ef4f7a2f3ea559a334a11cbe8392b610742bab..4eab4fa4290c270697c00be20840cf4e85459183 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -84,7 +85,7 @@ void MetricTableReport::WriteReportToInfoLog(double expected_metric_sum) { if (end_of_line == string::npos) { end_of_line = report.size(); } - tensorflow::StringPiece line(report.data() + pos, end_of_line - pos); + absl::string_view line(report.data() + pos, end_of_line - pos); // TODO(b/34779244): Figure out how to do this without the verbose log-line // prefix. The usual way didn't compile on open source. @@ -152,8 +153,8 @@ void MetricTableReport::AppendCategoryTable() { if (text.empty()) { text = "[no category]"; } - tensorflow::strings::StrAppend(&text, " (", category.entries.size(), " ", - entry_name_, ")"); + absl::StrAppend(&text, " (", category.entries.size(), " ", entry_name_, + ")"); AppendTableRow(text, category.metric_sum, metric_sum); // Show the top entries in the category. @@ -177,9 +178,9 @@ void MetricTableReport::AppendCategoryTable() { } const int64 remaining_categories = categories.size() - categories_shown; if (remaining_categories > 0) { - AppendTableRow(tensorflow::strings::StrCat("... (", remaining_categories, - " more categories)"), - expected_metric_sum_ - metric_sum, expected_metric_sum_); + AppendTableRow( + absl::StrCat("... (", remaining_categories, " more categories)"), + expected_metric_sum_ - metric_sum, expected_metric_sum_); } } @@ -206,9 +207,9 @@ void MetricTableReport::AppendEntryTable() { } const int64 remaining_entries = entries_.size() - entries_shown; if (remaining_entries > 0) { - AppendTableRow(tensorflow::strings::StrCat("... (", remaining_entries, - " more ", entry_name_, ")"), - expected_metric_sum_ - metric_sum, expected_metric_sum_); + AppendTableRow( + absl::StrCat("... (", remaining_entries, " more ", entry_name_, ")"), + expected_metric_sum_ - metric_sum, expected_metric_sum_); } } @@ -241,10 +242,10 @@ double MetricTableReport::UnaccountedMetric() { string MetricTableReport::MetricString(double metric) { // Round to integer and stringify. - string s1 = tensorflow::strings::StrCat(std::llround(metric)); + string s1 = absl::StrCat(std::llround(metric)); // Code below commafies the string, e.g. "1234" becomes "1,234". - tensorflow::StringPiece sp1(s1); + absl::string_view sp1(s1); string output; // Copy leading non-digit characters unconditionally. // This picks up the leading sign. @@ -263,8 +264,7 @@ string MetricTableReport::MetricString(double metric) { } string MetricTableReport::MetricPercent(double metric) { - return tensorflow::strings::Printf("%5.2f%%", - metric / expected_metric_sum_ * 100.0); + return absl::StrFormat("%5.2f%%", metric / expected_metric_sum_ * 100.0); } } // namespace xla diff --git a/tensorflow/compiler/xla/metric_table_report.h b/tensorflow/compiler/xla/metric_table_report.h index 818fb1d3fe0b8bbe1a8eba363ff6445e2f3df9d2..062d8ed99b213535ad39d840aaaf10a6fe0da84c 100644 --- a/tensorflow/compiler/xla/metric_table_report.h +++ b/tensorflow/compiler/xla/metric_table_report.h @@ -18,9 +18,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -108,7 +107,7 @@ class MetricTableReport { // Append all parameters to the report. template void AppendLine(Args... args) { - tensorflow::strings::StrAppend(&report_, std::forward(args)..., "\n"); + absl::StrAppend(&report_, std::forward(args)..., "\n"); } // Represents a set of entries with the same category_text. diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 6b7fd10d63f8f97b0e0bf7570488c06323368d75..0f86f9f35e105713aa3072a9ebf572d33d35d66d 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -19,15 +19,15 @@ limitations under the License. #include #include +#include "absl/base/casts.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -39,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file) PackedLiteralReader::~PackedLiteralReader() { delete file_; } -StatusOr> PackedLiteralReader::Read( - const Shape& shape, const Layout* layout) { +StatusOr PackedLiteralReader::Read(const Shape& shape, + const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) << " layout: " << (layout == nullptr ? "" : layout->ShortDebugString()); @@ -54,17 +54,17 @@ StatusOr> PackedLiteralReader::Read( if (shape.element_type() != F32) { return Unimplemented( "not yet implemented element type for packed literal reading: %s", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } - auto result = MakeUnique(literal_shape); - result->PopulateWithValue(std::numeric_limits::quiet_NaN()); + Literal result(literal_shape); + result.PopulateWithValue(std::numeric_limits::quiet_NaN()); int64 elements = ShapeUtil::ElementsIn(shape); - tensorflow::gtl::ArraySlice field = result->data(); - char* data = tensorflow::bit_cast(field.data()); + absl::Span field = result.data(); + char* data = absl::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); - tensorflow::StringPiece sp; + absl::string_view sp; auto s = file_->Read(offset_, bytes, &sp, data); offset_ += sp.size(); if (!s.ok()) { @@ -85,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const { // Try to read a single byte from offset_. If we can't, we've // exhausted the data. char single_byte[1]; - tensorflow::StringPiece sp; + absl::string_view sp; auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); return !s.ok(); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 98dccaa9a246520bf60217b96d67a13a24c34b4a..d6d2ff1521bab341b166c4f5c1dc0917e28573d8 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -41,8 +41,7 @@ class PackedLiteralReader { // // Layout is optional. If it is not provided, no layout is set on the literal // that is produced. - StatusOr> Read(const Shape& shape, - const Layout* layout = nullptr); + StatusOr Read(const Shape& shape, const Layout* layout = nullptr); // Returns whether the input file has been fully exhausted; i.e. all available // packed literals have been read and we're at the end of the file. diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index 787725e884c810fd724ab88ad7d4beaf3e0a6cc7..b507a2ef79f1d7e9ae632744675dddf574490805 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { @@ -49,16 +50,40 @@ string SanitizeFilename(const string& file_name) { return safe_file_name; } +std::pair>*> +GetDirectoryExpanders() { + static auto* mutex = new tensorflow::mutex; + static auto* singleton = new std::vector>; + return {mutex, singleton}; +} + +// Runs all the directory expanders over x and returns the result. +string Expand(string x) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + for (const auto& f : *pair.second) { + x = f(x); + } + return x; +} + } // namespace Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string expanded_dir = Expand(directory); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir)); string safe_file_name = SanitizeFileName(file_name) + ".pb"; - const string path = tensorflow::io::JoinPath(directory, safe_file_name); + const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name); return tensorflow::WriteBinaryProto(env, path, message); } +void RegisterDirectoryExpander(const std::function& expander) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + pair.second->push_back(expander); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 3667621367c7639c40ff17aee7b77305d4d34e33..f22fc8b8499dd4a5329276040331a2ed9e89bea9 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -39,6 +39,10 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name); +// Registers a function that may either expand a dirpath or forward the original +// dirpath along as-is. +void RegisterDirectoryExpander(const std::function& expander); + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index c8f2d65c223ccfe20862954c224d016cca421812..f0d84646b9f01ad3ad209073f13b7b3ec21635d1 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -39,6 +39,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/python:numpy_lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -59,6 +62,8 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 8246f76d3443d58f4174cc4f86100f54d6b46928..cd5fd330298fb0ff158e232dac121f8ffb271218 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, return client->TransferToInfeedLocal(literal, device_ordinal); } -StatusOr> TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number) { +StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number) { VLOG(1) << "Outfeeding literal from replica number: " << replica_number << " shape: " << shape; LocalClient* client = GetOrCreateLocalClient(); @@ -137,14 +137,12 @@ static StatusOr ToBuffer(LocalClient* client, /* static */ StatusOr LocalShapedBuffer::FromLiteral( - const Literal& argument, - const tensorflow::gtl::optional& shape_with_layout) { + const Literal& argument, const absl::optional& shape_with_layout) { LocalClient* client = GetOrCreateLocalClient(); StatusOr buf = [&] { if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + return ToBuffer(client, /*device_ordinal=*/0, relaid); } return ToBuffer(client, /*device_ordinal=*/0, argument); }(); @@ -152,7 +150,7 @@ StatusOr LocalShapedBuffer::FromLiteral( return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -StatusOr> LocalShapedBuffer::ToLiteral() const { +StatusOr LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); return client->ShapedBufferToLiteral(*shaped_buffer()); } @@ -161,16 +159,16 @@ CompiledLocalComputation::CompiledLocalComputation( std::unique_ptr executable) : executable_(std::move(executable)) {} -StatusOr> CompiledLocalComputation::Execute( +StatusOr CompiledLocalComputation::Execute( const std::vector& arguments, - const std::vector>& shapes_with_layout) { + const std::vector>& shapes_with_layout) { LocalClient* client = GetOrCreateLocalClient(); VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; // Each replica populates a StatusOr result, but only replica zero actually // retrieves its literal value. - std::vector>> results(GetReplicaCount()); + std::vector> results(GetReplicaCount()); { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", GetReplicaCount()); @@ -194,14 +192,13 @@ StatusOr> CompiledLocalComputation::Execute( scoped_buffers.reserve(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { const Literal& argument = arguments[i]; - const tensorflow::gtl::optional& shape_with_layout = + const absl::optional& shape_with_layout = shapes_with_layout[i]; StatusOr pushed; if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - pushed = ToBuffer(client, device_ordinal, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, relaid); } else { pushed = ToBuffer(client, device_ordinal, argument); } @@ -252,7 +249,7 @@ StatusOr> CompiledLocalComputation::Execute( return InternalError( "Failed running replica %d (other replicas may have failed as well): " "%s.", - replica, statusor.status().ToString().c_str()); + replica, statusor.status().ToString()); } } @@ -260,7 +257,7 @@ StatusOr> CompiledLocalComputation::Execute( } LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( - tensorflow::gtl::ArraySlice argument_handles) { + absl::Span argument_handles) { LocalClient* client = GetOrCreateLocalClient(); std::vector argument_buffers; @@ -370,8 +367,7 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { } LocalOp LocalComputationBuilder::Broadcast( - const LocalOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { + const LocalOp& operand, absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); } @@ -381,14 +377,14 @@ LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, return xla::Pad(operand.op(), padding_value.op(), padding_config); } -LocalOp LocalComputationBuilder::Reshape( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { +LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, + absl::Span dimensions, + absl::Span new_sizes) { return xla::Reshape(operand.op(), dimensions, new_sizes); } -LocalOp LocalComputationBuilder::Collapse( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, + absl::Span dimensions) { return xla::Collapse(operand.op(), dimensions); } @@ -396,10 +392,10 @@ LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { return xla::CrossReplicaSum(operand.op()); } -LocalOp LocalComputationBuilder::Slice( - const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { +LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return xla::Slice(operand.op(), start_indices, limit_indices, strides); } @@ -412,7 +408,7 @@ LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, LocalOp LocalComputationBuilder::DynamicSlice( const LocalOp& operand, const LocalOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } @@ -422,8 +418,8 @@ LocalOp LocalComputationBuilder::DynamicUpdateSlice( return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } -LocalOp LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, int64 dimension) { +LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, + int64 dimension) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -434,18 +430,16 @@ LocalOp LocalComputationBuilder::ConcatInDim( LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( const LocalOp& operand, const LocalComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const LocalOp& source, const LocalOp& init_value, - const LocalComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const LocalOp& source, + const LocalOp& init_value, const LocalComputation& scatter) { return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } -LocalOp LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { +LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { std::vector xla_ops; xla_ops.reserve(elements.size()); for (const auto& op : elements) { @@ -472,13 +466,14 @@ LocalOp LocalComputationBuilder::DotGeneral( LocalOp LocalComputationBuilder::ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, - lhs_dilation, rhs_dilation, dimension_numbers); + lhs_dilation, rhs_dilation, dimension_numbers, + feature_group_count); } LocalOp LocalComputationBuilder::ConvertElementType( @@ -491,9 +486,8 @@ LocalOp LocalComputationBuilder::BitcastConvertType( return xla::BitcastConvertType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { +LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, + absl::Span operands) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -503,19 +497,18 @@ LocalOp LocalComputationBuilder::Call( } LocalOp LocalComputationBuilder::Transpose( - const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + const LocalOp& operand, absl::Span permutation) { return xla::Transpose(operand.op(), permutation); } -LocalOp LocalComputationBuilder::Rev( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, + absl::Span dimensions) { return xla::Rev(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Map(absl::Span operands, + const LocalComputation& local_computation, + absl::Span dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -529,7 +522,7 @@ LocalOp LocalComputationBuilder::Map( LocalOp LocalComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + absl::Span dimensions_to_reduce) { return xla::Reduce(operand.op(), init_value.op(), local_computation.computation(), dimensions_to_reduce); } @@ -537,9 +530,9 @@ LocalOp LocalComputationBuilder::Reduce( LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding) { return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), window_dimensions, window_strides, padding); @@ -575,6 +568,16 @@ StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { return builder_.IsConstant(operand.op()); } +LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { + return xla::Sort(operand.op(), absl::nullopt, dimension); +} + +LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, + const LocalOp& values, + int64 dimension) { + return xla::Sort(keys.op(), values.op(), dimension); +} + StatusOr LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, @@ -590,10 +593,10 @@ StatusOr LocalComputationBuilder::BuildConstantSubGraph( #define _FORWARD_UNOP(method_name) \ _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) -#define _FORWARD_BINOP(method_name) \ - _FORWARD(method_name, LocalOp, \ - (const LocalOp& lhs, const LocalOp& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + absl::Span broadcast_dimensions), \ (lhs.op(), rhs.op(), broadcast_dimensions)) #define _FORWARD_TRIOP(method_name) \ @@ -640,7 +643,6 @@ _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) _FORWARD_UNOP(IsFinite) _FORWARD_UNOP(Neg) -_FORWARD_UNOP(Sort) _FORWARD_UNOP(Sqrt) _FORWARD_UNOP(Rsqrt) _FORWARD_UNOP(Square) @@ -688,8 +690,7 @@ StatusOr DestructureLocalShapedBufferTuple( "Attemped to destructure a LocalShapedBuffer that did not have a tuple " "shape; shape: %s", ShapeUtil::HumanString( - local_shaped_buffer->shaped_buffer()->on_device_shape()) - .c_str()); + local_shaped_buffer->shaped_buffer()->on_device_shape())); } DeviceMemoryAllocator* allocator = diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index a568c24c6376e1fe17f5e5a4f6626bf0970985a3..2166bb6721ca380f3180a8802e4922f2e9e45945 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace swig { @@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); // Transfers a literal of the given shape from the outfeed of the given replica. // // The replica number is resolved to an appropriate device ordinal. -StatusOr > TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number); +StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number); // Wraps a ScopedShapedBuffer produced by copying a literal "to // device," i.e. copying a literal to a scoped buffer via the local @@ -60,13 +60,12 @@ StatusOr > TransferFromOutfeedLocalReplica( class LocalShapedBuffer { public: static StatusOr FromLiteral( - const Literal& argument, - const tensorflow::gtl::optional& shape_with_layout); + const Literal& argument, const absl::optional& shape_with_layout); LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); const ScopedShapedBuffer* shaped_buffer() const; - StatusOr > ToLiteral() const; + StatusOr ToLiteral() const; // Transfers ownership of the encapsulated ShapedBuffer to the caller, // analogous to std::unique_ptr::release(). @@ -118,12 +117,12 @@ class CompiledLocalComputation { // with optionally-specified argument layouts. The literals will be // re-laid out according to the corresponding elements of // shapes_with_layout. - StatusOr > Execute( + StatusOr Execute( const std::vector& arguments, - const std::vector >& shapes_with_layout); + const std::vector >& shapes_with_layout); LocalShapedBuffer* ExecuteWithShapedBuffers( - tensorflow::gtl::ArraySlice argument_handles); + absl::Span argument_handles); private: std::unique_ptr executable_; @@ -200,46 +199,41 @@ class LocalComputationBuilder { LocalOp ConstantLiteral(const Literal& literal); LocalOp Broadcast(const LocalOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, const PaddingConfig& padding_config); - LocalOp Reshape(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, absl::Span dimensions, + absl::Span new_sizes); - LocalOp Collapse(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, absl::Span dimensions); LocalOp CrossReplicaSum(const LocalOp& operand); - LocalOp Slice(const LocalOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); LocalOp SliceInDim(const LocalOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, const LocalOp& start_indices); - LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(absl::Span operands, int64 dimension); LocalOp SelectAndScatterWithGeneralPadding( const LocalOp& operand, const LocalComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding, - const LocalOp& source, const LocalOp& init_value, - const LocalComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span > padding, const LocalOp& source, + const LocalOp& init_value, const LocalComputation& scatter); - LocalOp Tuple(tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(absl::Span elements); LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); @@ -250,11 +244,12 @@ class LocalComputationBuilder { LocalOp ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, + absl::Span > padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); LocalOp ConvertElementType(const LocalOp& operand, PrimitiveType new_element_type); @@ -263,28 +258,27 @@ class LocalComputationBuilder { PrimitiveType new_element_type); LocalOp Call(const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); LocalOp Transpose(const LocalOp& operand, - tensorflow::gtl::ArraySlice permutation); + absl::Span permutation); - LocalOp Rev(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, absl::Span dimensions); - LocalOp Map(tensorflow::gtl::ArraySlice operands, + LocalOp Map(absl::Span operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); LocalOp ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span > padding); LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, const Shape& shape); @@ -301,6 +295,11 @@ class LocalComputationBuilder { StatusOr IsConstant(const LocalOp& operand); + LocalOp Sort(const LocalOp& operand, int64 dimension); + + LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, + int64 dimension); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ @@ -312,7 +311,7 @@ class LocalComputationBuilder { #define _FORWARD_BINOP(method_name) \ _FORWARD(method_name, LocalOp, \ (const LocalOp& lhs, const LocalOp& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) + absl::Span broadcast_dimensions)) #define _FORWARD_TRIOP(method_name) \ _FORWARD(method_name, LocalOp, \ @@ -357,7 +356,6 @@ class LocalComputationBuilder { _FORWARD_UNOP(Tanh) _FORWARD_UNOP(IsFinite) _FORWARD_UNOP(Neg) - _FORWARD_UNOP(Sort) _FORWARD_UNOP(Sqrt) _FORWARD_UNOP(Rsqrt) _FORWARD_UNOP(Square) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 5d5a955bfee35b38a61b9a9f792c1b31259ce044..521490e76c138553c5cc6895412eadb35a939881 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,15 +22,15 @@ limitations under the License. // // C++ Python // -------------------------------------+--------------------------------------- -// ArraySlice <- sequence of int -// ArraySlice <- sequence of LocalOp +// Span <- sequence of int +// Span <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) // <- object duck-typed as xla_client.Shape // std::vector <- sequence of xla_client.Shape objects // PrimitiveType <- int -// ArraySlice> <- sequence of int pairs +// Span> <- sequence of int pairs // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto // DotDimensionNumbers proto <- corresponding Python proto @@ -109,10 +109,12 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/python/numpy_bridge.h" #include "tensorflow/compiler/xla/python/local_computation_builder.h" @@ -154,8 +156,8 @@ bool HandleStringAttribute(PyObject* o, return true; // The attribute is None, which we consider ok. } if (!PyString_Check(attr)) { - string message = tensorflow::strings::Printf("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr).c_str()); + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); PyErr_SetString(PyExc_TypeError, message.c_str()); Py_DECREF(attr); return false; // Type error, not ok. @@ -214,9 +216,9 @@ tensorflow::ImportNumpy(); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if ($1.ok()) { - std::unique_ptr value = $1.ConsumeValueOrDie(); + Literal value = $1.ConsumeValueOrDie(); $result = numpy::PyObjectFromXlaLiteral(*value); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -265,9 +267,9 @@ tensorflow::ImportNumpy(); $result = Py_None; } -// ArraySlice +// Span -%typemap(in) tensorflow::gtl::ArraySlice +%typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -297,9 +299,9 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ArraySlice +// Span -%typemap(in) tensorflow::gtl::ArraySlice( +%typemap(in) absl::Span( std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -321,7 +323,7 @@ tensorflow::ImportNumpy(); // LocalShapedBuffer* -%typemap(in) tensorflow::gtl::ArraySlice +%typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -344,25 +346,25 @@ tensorflow::ImportNumpy(); // Literal -%typemap(in) const Literal& (StatusOr< std::unique_ptr > literal_status) { +%typemap(in) const Literal& (StatusOr literal_status) { literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); SWIG_fail; } - $1 = literal_status.ValueOrDie().get(); + $1 = &literal_status.ValueOrDie(); } -%typemap(out) std::unique_ptr { +%typemap(out) Literal { $result = numpy::PyObjectFromXlaLiteral(*$1); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if (!$1.ok()) { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } - $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); + $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); } %typemap(in) const std::vector& (std::vector temps) { @@ -373,13 +375,13 @@ tensorflow::ImportNumpy(); const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - StatusOr< std::unique_ptr > literal_status = numpy::XlaLiteralFromPyObject(o); + StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); Py_DECREF(o); SWIG_fail; } - temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); + temps.push_back(literal_status.ConsumeValueOrDie()); Py_DECREF(o); } $1 = &temps; @@ -409,10 +411,10 @@ tensorflow::ImportNumpy(); $1 = &temp; } -%typemap(in) const tensorflow::gtl::optional& ( - tensorflow::gtl::optional temp) { +%typemap(in) const absl::optional& ( + absl::optional temp) { if ($input == Py_None) { - temp = tensorflow::gtl::nullopt; + temp = absl::nullopt; $1 = &temp; } else { StatusOr statusor = numpy::XlaShapeFromPyShape($input); @@ -448,8 +450,8 @@ tensorflow::ImportNumpy(); $1 = &temps; } -%typemap(in) const std::vector >& ( - std::vector > temps) { +%typemap(in) const std::vector >& ( + std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; @@ -458,7 +460,7 @@ tensorflow::ImportNumpy(); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); if (o == Py_None) { - temps.push_back(tensorflow::gtl::nullopt); + temps.push_back(absl::nullopt); } else { StatusOr statusor = numpy::XlaShapeFromPyShape(o); Py_DECREF(o); @@ -494,9 +496,9 @@ tensorflow::ImportNumpy(); $1 = static_cast(value); } -// ArraySlice> +// Span> -%typemap(in) tensorflow::gtl::ArraySlice > +%typemap(in) absl::Span > (std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -896,7 +898,7 @@ tensorflow::ImportNumpy(); if (o != Py_None) { StatusOr statusor = numpy::XlaShapeFromPyShape(o); if (!statusor.ok()) { - PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); + PyErr_SetString(PyExc_TypeError, absl::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); Py_DECREF(o); SWIG_fail; } @@ -1011,6 +1013,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Pow; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::LocalComputationBuilder::SortKeyVal; %unignore xla::swig::LocalComputationBuilder::Sqrt; %unignore xla::swig::LocalComputationBuilder::Rsqrt; %unignore xla::swig::LocalComputationBuilder::Square; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 6f665faf61b25b23a32ce4d0a012543ba18d7e64..b0aa024c7474cf8e6934432b2f364be464714999 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" @@ -149,9 +151,7 @@ static int NumpyTypenum(PyObject* o) { // // NOTE: this is an internal helper for conversion to a C++, and so decrefs r. static string ExtractStringAndDecref(PyObject* r) { - auto error = [r] { - return tensorflow::strings::Printf("", r); - }; + auto error = [r] { return absl::StrFormat("", r); }; if (r == nullptr) { return error(); } @@ -191,8 +191,8 @@ StatusOr XlaShapeFromPyShape(PyObject* o) { PyObject* result = PyObject_CallMethod(o, const_cast(method.c_str()), nullptr); if (result == nullptr) { - return error(tensorflow::strings::StrCat( - "Failed to call method of shape object:", method)); + return error( + absl::StrCat("Failed to call method of shape object:", method)); } return result; }; @@ -281,15 +281,15 @@ StatusOr XlaShapeFromPyShape(PyObject* o) { // Helper that retrieves the member with attr_name, stringifies it if is not // None, and returns it as a C++ string. -static tensorflow::gtl::optional GetAttrAsString( - PyObject* o, const string& attr_name) { +static absl::optional GetAttrAsString(PyObject* o, + const string& attr_name) { if (!PyObject_HasAttrString(o, attr_name.c_str())) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); if (attr == Py_None) { Py_DECREF(attr); - return tensorflow::gtl::nullopt; + return absl::nullopt; } string result = PyObjectCppStr(attr); Py_DECREF(attr); @@ -298,48 +298,46 @@ static tensorflow::gtl::optional GetAttrAsString( // Helper that retrieves the member with attr_name, checks that it is an integer // if it is not None, and returns it as an int32 value. -static tensorflow::gtl::optional GetAttrAsInt32( - PyObject* o, const string& attr_name) { +static absl::optional GetAttrAsInt32(PyObject* o, + const string& attr_name) { if (!PyObject_HasAttrString(o, attr_name.c_str())) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); if (attr == Py_None) { Py_DECREF(attr); - return tensorflow::gtl::nullopt; + return absl::nullopt; } if (!CheckPyIntOrLong(attr)) { Py_DECREF(attr); - return tensorflow::gtl::nullopt; + return absl::nullopt; } long value = PyIntOrPyLongToLong(attr); // NOLINT Py_DECREF(attr); if (value == -1 && PyErr_Occurred() != nullptr) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } if (static_cast(value) != value) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } return value; } StatusOr OpMetadataFromPyObject(PyObject* o) { OpMetadata result; - tensorflow::gtl::optional op_type = GetAttrAsString(o, "op_type"); + absl::optional op_type = GetAttrAsString(o, "op_type"); if (op_type.has_value()) { result.set_op_type(op_type.value()); } - tensorflow::gtl::optional op_name = GetAttrAsString(o, "op_name"); + absl::optional op_name = GetAttrAsString(o, "op_name"); if (op_name.has_value()) { result.set_op_name(op_name.value()); } - tensorflow::gtl::optional source_file = - GetAttrAsString(o, "source_file"); + absl::optional source_file = GetAttrAsString(o, "source_file"); if (source_file.has_value()) { result.set_source_file(source_file.value()); } - tensorflow::gtl::optional source_line = - GetAttrAsInt32(o, "source_line"); + absl::optional source_line = GetAttrAsInt32(o, "source_line"); if (source_line.has_value()) { result.set_source_line(source_line.value()); } @@ -370,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { } } -StatusOr> XlaLiteralFromPyObject(PyObject* o) { +StatusOr XlaLiteralFromPyObject(PyObject* o) { if (PyTuple_Check(o)) { int num_elements = PyTuple_Size(o); - std::vector> elements; + std::vector elements; elements.reserve(num_elements); for (int i = 0; i < num_elements; i++) { PyObject* element = PyTuple_GetItem(o, i); @@ -391,8 +389,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { int np_type = PyArray_TYPE(py_array); auto literal = LiteralUtil::CreateFromDimensions( NumpyTypeToPrimitiveType(np_type), dimensions); - TF_RETURN_IF_ERROR( - CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); + TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal)); return std::move(literal); } else { return InvalidArgument( diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index a67c93a4fb7413f9bbcb9afd92c36fd118836e1f..40ff2d9ad214cc4dcad42234fa296834cbc92882 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -25,9 +25,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/python/lib/core/numpy.h" namespace xla { @@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // To avoid transferring ownership of the data buffers that underlie // PyArrays and XLA literals, this function makes deep copies of all // array data. -StatusOr > XlaLiteralFromPyObject(PyObject* o); +StatusOr XlaLiteralFromPyObject(PyObject* o); // The following functions copy array data from the buffers underlying Numpy // ndarrays into those underlying XLA literals, and vice versa. diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index a2c6fc344d192265d536ef7e23ad5c6d7c847014..bb303c5678a2cac9a9e78925e857ab25c0c6d9be 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -105,7 +105,6 @@ _UNARY_OPS = [ 'Square', 'Reciprocal', 'Neg', - 'Sort', 'Erf', 'Erfc', 'ErfInv', @@ -1110,7 +1109,7 @@ class ComputationBuilder(object): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) return self._client.DotGeneral(lhs, rhs, dimension_numbers) - def Conv(self, lhs, rhs, window_strides, padding): + def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1): """Enqueues a Conv operation onto the computation. Args: @@ -1118,6 +1117,7 @@ class ComputationBuilder(object): rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. + feature_group_count: number of feature groups for grouped convolution. Returns: a LocalOp representing the Conv operation. """ @@ -1126,10 +1126,11 @@ class ComputationBuilder(object): self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), - (), dimension_numbers) + (), dimension_numbers, + feature_group_count) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation): + lhs_dilation, rhs_dilation, feature_group_count=1): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: @@ -1139,6 +1140,7 @@ class ComputationBuilder(object): padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. rhs_dilation: length-N array-like of dilation factors. + feature_group_count: number of feature groups for grouped convolution. Returns: A ComputationdataHandle representing the added ConvWithGeneralPadding op. @@ -1146,7 +1148,8 @@ class ComputationBuilder(object): dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers) + dimension_numbers, + feature_group_count) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1164,7 +1167,8 @@ class ComputationBuilder(object): return dimension_numbers def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers): + rhs_dilation, dimension_numbers, + feature_group_count=1): """Enqueues a ConvGeneralDilated operation onto the computation. Args: @@ -1191,6 +1195,7 @@ class ComputationBuilder(object): labels appear in the rhs_spec string, so that window_strides[0] is matched with the dimension corresponding to the first character appearing in rhs_spec that is not 'I' or 'O'. + feature_group_count: number of feature groups for grouped convolution. Returns: a LocalOp representing the ConvGenralDilated operation. """ @@ -1216,7 +1221,16 @@ class ComputationBuilder(object): key=lambda i: rhs_spec.index(out_spec[i]))) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers) + dimension_numbers, + feature_group_count) + + def Sort(self, operand, dimension=-1): + """Enqueues a sort operation onto the computation.""" + return self._client.Sort(operand, dimension) + + def SortKeyVal(self, keys, values, dimension=-1): + """Enqueues a key-value sort operation onto the computation.""" + return self._client.SortKeyVal(keys, values, dimension) def _forward_methods_to_local_builder(): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index fd98e19457f61aade947aa354d2e415148d127f6..82103f03132e45ff822ce1ebcc2be47b24f5869f 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -661,6 +661,30 @@ class SingleOpTest(LocalComputationTest): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testConvGeneralDilatedGroupedConvolutionF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 2, 3) + rhs = a(2, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + feature_group_count = 2 + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]], + [[0., 0., 0.], + [330., 380., 160.], + [0., 0., 0.], + [480., 530., 220.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index a803520876952a0ab67ecb827b1f256c915335f9..ceb5e74db7c3b9305e9d77068df9ae0a3690af8a 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -43,7 +44,7 @@ std::unique_ptr> MatmulArray2DImpl( int m = lhs.height(); int n = rhs.width(); int k = lhs.width(); - auto result = MakeUnique>(m, n); + auto result = absl::make_unique>(m, n); // Because Eigen is a header-oriented library, make sure that the Eigen code // is the same as the code used by the CPU backend (otherwise the linker will // randomly pick *some* definition). @@ -77,7 +78,8 @@ std::unique_ptr> MatmulArray2DImpl( /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( const Array2D& input) { - auto result = MakeUnique>(input.height(), input.width()); + auto result = + absl::make_unique>(input.height(), input.width()); for (int64 rowno = 0; rowno < input.height(); ++rowno) { for (int64 colno = 0; colno < input.height(); ++colno) { (*result)(rowno, colno) = input(rowno, colno); @@ -106,17 +108,15 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( // array by adding a fourth dummy dimension of size 1 without stride, padding // and dilation. Array4D a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1); - a4dlhs.Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); - }); + a4dlhs.Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); + }); Array4D a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1); - a4drhs.Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); - }); + a4drhs.Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); + }); // Add a second dummy spatial dimensions. ConvolutionDimensionNumbers dnums2d = dnums; dnums2d.add_input_spatial_dimensions(3); @@ -126,13 +126,12 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, {rhs_dilation, 1}, dnums2d); - auto convr3 = MakeUnique>(convr4->planes(), convr4->depth(), - convr4->height()); - convr4->Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; - }); + auto convr3 = absl::make_unique>( + convr4->planes(), convr4->depth(), convr4->height()); + convr4->Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; + }); return convr3; } @@ -187,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); @@ -201,7 +199,7 @@ ReferenceUtil::ReduceWindow1DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0]); + auto result = absl::make_unique>(window_counts[0]); // Do a full 1D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { @@ -219,10 +217,10 @@ ReferenceUtil::ReduceWindow1DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DAdd( - const tensorflow::gtl::ArraySlice& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { +ReferenceUtil::ReduceWindow1DAdd(absl::Span operand, float init, + absl::Span window, + absl::Span stride, + Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{static_cast(operand.size())}; return ReduceWindow1DGeneric( @@ -234,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd( ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.height(), operand.width()}; std::vector window_counts(window.size(), 0); @@ -247,7 +244,8 @@ ReferenceUtil::ReduceWindow2DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1]); + auto result = + absl::make_unique>(window_counts[0], window_counts[1]); // Do a full 2D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { @@ -272,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( - const Array2D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{operand.height(), operand.width()}; return ReduceWindow2DGeneric( @@ -283,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( - const Array3D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -296,8 +292,8 @@ ReferenceUtil::ReduceWindow2DGeneric( WindowCount(dim_lengths[i], window[i], stride[i], padding); pad_low[i] = padding_both[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1], - window_counts[2]); + auto result = absl::make_unique>( + window_counts[0], window_counts[1], window_counts[2]); for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { @@ -331,8 +327,8 @@ ReferenceUtil::ReduceWindow2DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + absl::Span window, absl::Span stride, + Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; return ReduceWindow4DGeneric( @@ -344,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; @@ -358,8 +353,8 @@ ReferenceUtil::ReduceWindow4DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1], - window_counts[2], window_counts[3]); + auto result = absl::make_unique>( + window_counts[0], window_counts[1], window_counts[2], window_counts[3]); // Do a full 4D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { @@ -398,9 +393,8 @@ ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DAdd( - const Array4D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, padding); @@ -421,13 +415,15 @@ ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::SelectAndScatter4DGePlus( - const Array4D& operand, const Array4D& source, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, bool same_padding) { +ReferenceUtil::SelectAndScatter4DGePlus(const Array4D& operand, + const Array4D& source, + float init, + absl::Span window, + absl::Span stride, + bool same_padding) { Padding padding = same_padding ? Padding::kSame : Padding::kValid; - auto result = MakeUnique>(operand.n1(), operand.n2(), - operand.n3(), operand.n4()); + auto result = absl::make_unique>(operand.n1(), operand.n2(), + operand.n3(), operand.n4()); std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -526,13 +522,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } ordered_input_dimensions[0] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0)); ordered_input_dimensions[1] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1)); ordered_kernel_dimensions[0] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)); ordered_kernel_dimensions[1] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)); std::vector> paddings = MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, @@ -543,7 +539,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim; dim.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0))); dim.set_stride(kernel_stride.first); dim.set_padding_low(paddings[0].first); dim.set_padding_high(paddings[0].second); @@ -553,7 +549,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim2; dim2.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1))); dim2.set_stride(kernel_stride.second); dim2.set_padding_low(paddings[1].first); dim2.set_padding_high(paddings[1].second); @@ -561,35 +557,39 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( dim2.set_base_dilation(lhs_dilation.second); *window.add_dimensions() = dim2; - const Shape& shape = - ShapeInference::InferConvolveShape(lhs_literal->shape(), - rhs_literal->shape(), window, dnums) - .ConsumeValueOrDie(); + const Shape& shape = ShapeInference::InferConvolveShape( + lhs_literal.shape(), rhs_literal.shape(), + /*feature_group_count=*/1, window, dnums) + .ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfig::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, precision_config)); HloModuleConfig config; HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; - std::unique_ptr result_literal = + Literal result_literal = evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); + CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); auto result = - MakeUnique>(result_literal->shape().dimensions(0), - result_literal->shape().dimensions(1), - result_literal->shape().dimensions(2), - result_literal->shape().dimensions(3)); + absl::make_unique>(result_literal.shape().dimensions(0), + result_literal.shape().dimensions(1), + result_literal.shape().dimensions(2), + result_literal.shape().dimensions(3)); - result->Each([&](tensorflow::gtl::ArraySlice indices, float* value) { - *value = result_literal->Get(indices); + result->Each([&](absl::Span indices, float* value) { + *value = result_literal.Get(indices); }); return result; @@ -601,7 +601,7 @@ ReferenceUtil::ReduceToColArray2D( const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(); + auto result = absl::make_unique>(); for (int64 i = 0; i < rows; ++i) { float acc = init; for (int64 j = 0; j < cols; ++j) { @@ -618,7 +618,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(); + auto result = absl::make_unique>(); for (int64 i = 0; i < cols; ++i) { float acc = init; for (int64 j = 0; j < rows; ++j) { @@ -630,8 +630,7 @@ ReferenceUtil::ReduceToRowArray2D( } /*static*/ std::vector ReferenceUtil::Reduce4DTo1D( - const Array4D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array4D& array, float init, absl::Span dims, const std::function& reduce_function) { std::vector result; CHECK_EQ(dims.size(), 3); @@ -674,8 +673,8 @@ ReferenceUtil::ReduceToRowArray2D( /* static */ std::unique_ptr> ReferenceUtil::Broadcast1DTo4D( const std::vector& array, const std::vector& bounds, int64 broadcast_from_dim) { - auto result = - MakeUnique>(bounds[0], bounds[1], bounds[2], bounds[3]); + auto result = absl::make_unique>(bounds[0], bounds[1], + bounds[2], bounds[3]); for (int64 i = 0; i < result->n1(); ++i) { for (int64 j = 0; j < result->n2(); ++j) { for (int64 k = 0; k < result->n3(); ++k) { @@ -704,13 +703,12 @@ ReferenceUtil::ReduceToRowArray2D( } /* static */ std::unique_ptr> ReferenceUtil::Reduce3DTo2D( - const Array3D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array3D& array, float init, absl::Span dims, const std::function& reduce_function) { CHECK_EQ(dims.size(), 1); int64 rows = dims[0] == 0 ? array.n2() : array.n1(); int64 cols = dims[0] == 2 ? array.n2() : array.n3(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); result->Fill(init); for (int i0 = 0; i0 < array.n1(); ++i0) { for (int i1 = 0; i1 < array.n2(); ++i1) { @@ -730,7 +728,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& map_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(matrix(i, j)); @@ -746,7 +744,7 @@ ReferenceUtil::ReduceToRowArray2D( CHECK_EQ(lhs.width(), rhs.width()); int64 rows = lhs.height(); int64 cols = rhs.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); @@ -760,7 +758,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& map_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(matrix(i, j), i, j); diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 8fa6961d197dce519cf151283b8bc0836a4615c0..8654fbb9b5e16c5ac13cb29aafeef8d142dbe39f 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -22,14 +22,14 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +42,8 @@ class ReferenceUtil { template static std::unique_ptr> TransposeArray2D( const Array2D& operand) { - auto result = MakeUnique>(operand.width(), operand.height()); + auto result = + absl::make_unique>(operand.width(), operand.height()); for (int64 w = 0; w < operand.width(); ++w) { for (int64 h = 0; h < operand.height(); ++h) { (*result)(w, h) = operand(h, w); @@ -143,8 +144,7 @@ class ReferenceUtil { // Returns the result of reducing the 4D array to a vector, reducing away // the dimensions specified in dims. static std::vector Reduce4DTo1D( - const Array4D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array4D& array, float init, absl::Span dims, const std::function& reduce_function); // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`. @@ -155,8 +155,7 @@ class ReferenceUtil { // Returns the result of reducing the 3D array to a 2D array, reducing away // the dimensions specified in dims. static std::unique_ptr> Reduce3DTo2D( - const Array3D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array3D& array, float init, absl::Span dims, const std::function& reduce_function); // Applies map_function to each element in the input (2D array) and returns @@ -178,47 +177,41 @@ class ReferenceUtil { // Windowed reductions with Add as the function to apply. static std::unique_ptr> ReduceWindow1DAdd( - const tensorflow::gtl::ArraySlice& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + absl::Span operand, float init, + absl::Span window, absl::Span stride, + Padding padding); static std::unique_ptr> ReduceWindow2DAdd( - const Array2D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow3DAdd( - const Array3D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow4DAdd( - const Array4D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); // Windowed reductions with a generic reduce function. static std::unique_ptr> ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + absl::Span window, absl::Span stride, + Padding padding); // With arbitrary padding. static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); // Batch normalize data. static std::unique_ptr> BatchNorm4D( @@ -231,8 +224,8 @@ class ReferenceUtil { // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, bool same_padding); + absl::Span window, absl::Span stride, + bool same_padding); // Concatenates the lhs and rhs arrays along the concatenate_dimension. // E.g. if concatenate_dimension is 0, the "n1"/height dimension is @@ -242,7 +235,7 @@ class ReferenceUtil { const Array2D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 2); - auto result = MakeUnique>( + auto result = absl::make_unique>( concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(), concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2()); for (int64 i0 = 0; i0 < result->n1(); ++i0) { @@ -276,7 +269,8 @@ class ReferenceUtil { out_dims[i] = lhs_dims[i] + rhs_dims[i]; } } - auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2]); + auto result = + absl::make_unique>(out_dims[0], out_dims[1], out_dims[2]); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -310,8 +304,8 @@ class ReferenceUtil { out_dims[i] = lhs_dims[i] + rhs_dims[i]; } } - auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2], - out_dims[3]); + auto result = absl::make_unique>(out_dims[0], out_dims[1], + out_dims[2], out_dims[3]); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -332,8 +326,8 @@ class ReferenceUtil { // Slices with index clamping template - static std::vector ClampSlice1D( - const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + static std::vector ClampSlice1D(absl::Span input, int64 start, + int64 size) { start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { @@ -355,9 +349,9 @@ class ReferenceUtil { CHECK_LE(limits[1], input.n2()); CHECK_GE(strides[0], 1); CHECK_GE(strides[1], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { (*result)(i0, i1) = @@ -381,10 +375,10 @@ class ReferenceUtil { CHECK_GE(strides[0], 1); CHECK_GE(strides[1], 1); CHECK_GE(strides[2], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1]), - CeilOfRatio(limits[2] - starts[2], strides[2])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { @@ -415,11 +409,11 @@ class ReferenceUtil { CHECK_GE(strides[1], 1); CHECK_GE(strides[2], 1); CHECK_GE(strides[3], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1]), - CeilOfRatio(limits[2] - starts[2], strides[2]), - CeilOfRatio(limits[3] - starts[3], strides[3])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2]), + CeilOfRatio(limits[3] - starts[3], strides[3])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -460,8 +454,8 @@ class ReferenceUtil { template static std::unique_ptr> MapWithIndexArray4D( const Array4D& input, F&& map_function) { - auto result = MakeUnique>(input.planes(), input.depth(), - input.height(), input.width()); + auto result = absl::make_unique>( + input.planes(), input.depth(), input.height(), input.width()); for (int64 plane = 0; plane < input.planes(); ++plane) { for (int64 depth = 0; depth < input.depth(); ++depth) { for (int64 height = 0; height < input.height(); ++height) { @@ -495,8 +489,8 @@ class ReferenceUtil { template static std::unique_ptr> MapWithIndexArray4D( const Array4D& lhs, const Array4D& rhs, F&& map_function) { - auto result = MakeUnique>(lhs.planes(), lhs.depth(), - lhs.height(), lhs.width()); + auto result = absl::make_unique>(lhs.planes(), lhs.depth(), + lhs.height(), lhs.width()); for (int64 plane = 0; plane < lhs.planes(); ++plane) { for (int64 depth = 0; depth < lhs.depth(); ++depth) { for (int64 height = 0; height < lhs.height(); ++height) { @@ -530,7 +524,7 @@ class ReferenceUtil { int64 out1 = in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; - auto result = MakeUnique>(out0, out1); + auto result = absl::make_unique>(out0, out1); result->Fill(pad); int64 o0 = low_padding0; for (int64 i0 = 0; i0 < in0; ++i0) { @@ -631,7 +625,7 @@ class ReferenceUtil { Array4D result(output_bounds[0], output_bounds[1], output_bounds[2], output_bounds[3]); result.Each( - [&](tensorflow::gtl::ArraySlice indices, NativeT* value) { + [&](absl::Span indices, NativeT* value) { for (int i = 0; i < 4; ++i) { bool in_low_padding = indices[i] < pad_low[i]; bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; @@ -669,7 +663,7 @@ class ReferenceUtil { static std::unique_ptr> ApplyElementwise2D( F&& f, const Array2D& array1, const Array2D&... arrays) { AssertSameSize2D(array1, arrays...); - auto result = MakeUnique>(array1.n1(), array1.n2()); + auto result = absl::make_unique>(array1.n1(), array1.n2()); for (int64 i = 0; i < array1.n1(); ++i) { for (int64 j = 0; j < array1.n2(); ++j) { (*result)(i, j) = f(array1(i, j), arrays(i, j)...); diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 8091bed4996a753649a5ecedda69a1ae48fb5897..a1b0f4045ff071454451f9fe3942ac974f4f47ac 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -36,7 +36,7 @@ namespace { class ReferenceUtilTest : public ::testing::Test { protected: ReferenceUtilTest() { - matrix_ = MakeUnique>(rows_, cols_); + matrix_ = absl::make_unique>(rows_, cols_); // [1.f 2.f 3.f] // [4.f 5.f 6.f] for (int64 i = 0; i < rows_; ++i) { @@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MatmulArray2D) { @@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({6.f, 15.f}, actual_literal, ErrorSpec(0.0001)); } @@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, actual_literal, ErrorSpec(0.0001)); } @@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { auto result = LiteralUtil::CreateR1(ReferenceUtil::Reduce4DTo1D( Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, [](float a, float b) { return a + b; })); - LiteralTestUtil::ExpectR1Equal({0}, *result); + LiteralTestUtil::ExpectR1Equal({0}, result); } TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal, ErrorSpec(0.0001)); } @@ -108,12 +108,12 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray4D) { - auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, - /*height=*/4, /*width=*/5); + auto input = absl::make_unique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); @@ -121,13 +121,13 @@ TEST_F(ReferenceUtilTest, MapArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { - auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, - /*height=*/4, /*width=*/5); + auto input = absl::make_unique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); input->FillWithMultiples(1.0f); auto subtract_index = [](float value, int64 plane, int64 depth, int64 height, int64 width) { @@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, - *actual_literal, ErrorSpec(0.0001)); + LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray3D) { @@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, + {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal, ErrorSpec(0.0001)); } @@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, - *actual_literal, ErrorSpec(0.0001)); + {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray4D) { @@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray4D) { @@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, {{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { @@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } } // namespace diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 44b22a5586dee3f7dd8ea0edbf9deb2090986ac8..3abb3855a42b8b5222115262448d359da3a80e87 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -34,15 +34,25 @@ cc_library( ], ) -tf_cc_binary( - name = "grpc_service_main_cpu", +cc_library( + name = "grpc_service_main_library", srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", "//tensorflow:grpc++", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_binary( + name = "grpc_service_main_cpu", + deps = [ + ":grpc_service_main_library", + "//tensorflow/compiler/xla/service:cpu_plugin", ], ) @@ -62,6 +72,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 67886761813f0bb45a600661b017be91ffeade73..84fe5b17d10fba8c9f44314bec2b827e98ff6b33 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -23,12 +23,12 @@ limitations under the License. #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/subprocess.h" @@ -46,7 +46,7 @@ class GRPCClientTestBase : public ::testing::Test { int port = tensorflow::internal::PickUnusedPortOrDie(); subprocess_.SetProgram( service_main_path, - {service_main_path, tensorflow::strings::Printf("--port=%d", port)}); + {service_main_path, absl::StrFormat("--port=%d", port)}); subprocess_.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_DUPPARENT); subprocess_.SetChannelAction(tensorflow::CHAN_STDERR, @@ -54,9 +54,8 @@ class GRPCClientTestBase : public ::testing::Test { CHECK(subprocess_.Start()); LOG(INFO) << "Launched subprocess"; - auto channel = - ::grpc::CreateChannel(tensorflow::strings::Printf("localhost:%d", port), - ::grpc::InsecureChannelCredentials()); + auto channel = ::grpc::CreateChannel(absl::StrFormat("localhost:%d", port), + ::grpc::InsecureChannelCredentials()); channel->WaitForConnected(gpr_time_add( gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN))); LOG(INFO) << "Channel to server is connected on port " << port; @@ -96,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + Literal expected_literal = LiteralUtil::CreateR1(expected); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal, ErrorSpec(0.0001))); } diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index c68c857c304138ff4318e243f66547c6acce1005..522ab99fb1feff69610af887b58f197211cdb21f 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -18,8 +18,9 @@ limitations under the License. #include "grpcpp/security/server_credentials.h" #include "grpcpp/server.h" #include "grpcpp/server_builder.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" @@ -29,8 +30,15 @@ namespace { int RealMain(int argc, char** argv) { int32 port = 1685; + bool any_address = false; + string platform_str; std::vector flag_list = { - tensorflow::Flag("port", &port, "port to listen on"), + tensorflow::Flag("platform", &platform_str, + "The XLA platform this service should be bound to"), + tensorflow::Flag("port", &port, "The TCP port to listen on"), + tensorflow::Flag( + "any", &any_address, + "Whether to listen to any host address or simply localhost"), }; string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); @@ -40,19 +48,24 @@ int RealMain(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); + se::Platform* platform = nullptr; + if (!platform_str.empty()) { + platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie(); + } std::unique_ptr service = - xla::GRPCService::NewService().ConsumeValueOrDie(); + xla::GRPCService::NewService(platform).ConsumeValueOrDie(); ::grpc::ServerBuilder builder; - string server_address(tensorflow::strings::Printf("localhost:%d", port)); + string server_address( + absl::StrFormat("%s:%d", any_address ? "[::]" : "localhost", port)); + builder.SetMaxReceiveMessageSize(INT_MAX); builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); builder.RegisterService(service.get()); std::unique_ptr<::grpc::Server> server(builder.BuildAndStart()); LOG(INFO) << "Server listening on " << server_address; server->Wait(); - return 0; } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 528b7fdfd3c39cc3a56afc92474dbae976a08ba8..e800cf470cfd129f93c2a1be586e03bebcaec987 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -86,6 +87,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -99,9 +101,11 @@ cc_library( ":bfloat16_support", ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -120,6 +124,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -156,6 +161,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -175,6 +181,10 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -191,6 +201,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -226,6 +237,7 @@ cc_library( hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_query", ":shape_inference", "//tensorflow/compiler/xla:literal", @@ -237,6 +249,12 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -263,6 +281,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -275,6 +294,7 @@ cc_library( "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", + "hlo_schedule.cc", "hlo_sharding.cc", ], hdrs = [ @@ -287,6 +307,7 @@ cc_library( "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", + "hlo_schedule.h", "hlo_sharding.h", ], deps = [ @@ -311,6 +332,13 @@ cc_library( "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -326,6 +354,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -336,8 +365,11 @@ cc_library( hdrs = ["pattern_matcher.h"], deps = [ ":hlo", + ":hlo_casting_utils", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/utility", ], ) @@ -363,6 +395,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/types:span", ], ) @@ -375,6 +408,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -389,7 +423,8 @@ cc_library( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -419,6 +454,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -449,6 +485,9 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -466,6 +505,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -514,9 +554,11 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -535,6 +577,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -550,8 +593,10 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -570,10 +615,13 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -613,7 +661,12 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) @@ -646,6 +699,10 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -668,6 +725,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -718,6 +776,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -735,6 +797,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -752,9 +815,11 @@ cc_library( ":hlo_execution_profile", ":hlo_graph_dumper", ":hlo_proto", + ":maybe_owning_device_memory", ":shaped_buffer", ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -765,6 +830,10 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], ) @@ -783,6 +852,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/types:span", ], ) @@ -812,6 +882,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -830,6 +903,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -846,6 +921,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -863,6 +939,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -873,6 +952,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -907,6 +987,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -916,12 +998,15 @@ tf_cc_test( deps = [ ":buffer_liveness", ":hlo", + ":hlo_dataflow_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -938,8 +1023,8 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", + ":hlo_memory_scheduler", ":hlo_proto", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -949,6 +1034,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -963,8 +1052,8 @@ tf_cc_test( ":cpu_plugin", ":flatten_call_graph", ":hlo", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -974,8 +1063,11 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -995,6 +1087,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1005,14 +1099,15 @@ tf_cc_test( deps = [ ":hlo", ":hlo_dataflow_analysis", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1030,6 +1125,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1046,8 +1142,43 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "hlo_module_group", + srcs = ["hlo_module_group.cc"], + hdrs = ["hlo_module_group.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_module_group_test", + srcs = ["hlo_module_group_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":hlo_module_group", + ":hlo_module_group_metadata", + ":hlo_parser", + ":hlo_proto", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -1058,12 +1189,15 @@ cc_library( deps = [ ":hlo", ":hlo_casting_utils", + ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", ], ) @@ -1073,6 +1207,7 @@ cc_library( hdrs = ["hlo_module_group_util.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_module_group_metadata", ":hlo_reachability", "//tensorflow/compiler/xla:status", @@ -1081,17 +1216,41 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_schedule_test", + srcs = ["hlo_schedule_test.cc"], + deps = [ + ":heap_simulator", + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) cc_library( - name = "hlo_scheduling", - srcs = ["hlo_scheduling.cc"], - hdrs = ["hlo_scheduling.h"], + name = "hlo_memory_scheduler", + srcs = ["hlo_memory_scheduler.cc"], + hdrs = ["hlo_memory_scheduler.h"], deps = [ ":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1100,24 +1259,27 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) tf_cc_test( - name = "hlo_scheduling_test", - srcs = ["hlo_scheduling_test.cc"], + name = "hlo_memory_scheduler_test", + srcs = ["hlo_memory_scheduler_test.cc"], deps = [ - ":buffer_value", ":heap_simulator", ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", + ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1141,6 +1303,8 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1166,6 +1330,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1180,6 +1345,9 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1195,8 +1363,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1215,6 +1385,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -1230,6 +1402,22 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + ], +) + +cc_library( + name = "scatter_expander", + srcs = ["scatter_expander.cc"], + hdrs = ["scatter_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1250,8 +1438,10 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1261,6 +1451,7 @@ cc_library( hdrs = ["algebraic_simplifier.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_creation_utils", ":hlo_pass", ":hlo_query", @@ -1274,6 +1465,11 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -1283,6 +1479,7 @@ tf_cc_test( deps = [ ":algebraic_simplifier", ":hlo", + ":hlo_casting_utils", ":hlo_matchers", ":hlo_pass", "//tensorflow/compiler/xla:literal", @@ -1297,6 +1494,8 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1308,8 +1507,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1362,6 +1560,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1384,6 +1583,52 @@ tf_cc_test( ], ) +cc_library( + name = "convolution_feature_group_converter", + srcs = ["convolution_feature_group_converter.cc"], + hdrs = ["convolution_feature_group_converter.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "convolution_feature_group_converter_test", + size = "small", + srcs = ["convolution_feature_group_converter_test.cc"], + deps = [ + ":convolution_feature_group_converter", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + +cc_library( + name = "while_loop_analysis", + srcs = ["while_loop_analysis.cc"], + hdrs = ["while_loop_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1391,10 +1636,12 @@ cc_library( deps = [ ":call_inliner", ":hlo", - ":hlo_evaluator", ":hlo_pass", + ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -1408,6 +1655,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1507,6 +1755,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -1522,6 +1771,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1542,6 +1792,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1556,6 +1807,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -1573,8 +1825,10 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -1594,6 +1848,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = True, # Contains per-platform computation placer registration ) @@ -1607,6 +1863,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1644,6 +1902,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/types:span", ], ) @@ -1684,6 +1943,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1698,6 +1959,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1729,6 +1991,8 @@ tf_cc_binary( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1737,6 +2001,9 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":hlo", + ":hlo_matchers", + ":hlo_memory_scheduler", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1745,6 +2012,9 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -1760,6 +2030,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1787,6 +2059,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1804,6 +2078,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1822,6 +2099,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1863,6 +2144,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1899,6 +2182,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1919,6 +2203,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1939,6 +2225,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1956,6 +2243,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -1968,7 +2256,6 @@ cc_library( ":hlo_dataflow_analysis", ":logical_buffer", ":logical_buffer_analysis", - "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1976,6 +2263,11 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -1994,6 +2286,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2026,6 +2319,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -2048,6 +2345,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2067,6 +2365,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2115,7 +2414,10 @@ cc_library( ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2144,21 +2446,23 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", - ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", - ":tuple_simplifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2175,6 +2479,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2198,6 +2503,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2240,9 +2546,11 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -2254,6 +2562,7 @@ cc_library( ], deps = [ ":hlo", + ":hlo_module_group", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -2279,6 +2588,29 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_test( + name = "hlo_pass_pipeline_test", + srcs = ["hlo_pass_pipeline_test.cc"], + deps = [ + ":hlo", + ":hlo_parser", + ":hlo_pass_pipeline", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -2295,6 +2627,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2313,9 +2646,11 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2332,6 +2667,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2342,12 +2678,14 @@ tf_cc_test( ":hlo", ":hlo_constant_folding", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2363,6 +2701,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2377,6 +2716,8 @@ cc_library( "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -2437,6 +2778,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2483,6 +2825,22 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "maybe_owning_device_memory", + srcs = [ + "maybe_owning_device_memory.cc", + ], + hdrs = [ + "maybe_owning_device_memory.h", + ], + deps = [ + ":device_memory_allocator", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", ], ) @@ -2492,6 +2850,7 @@ cc_library( hdrs = ["elemental_ir_emitter.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_module_config", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2500,11 +2859,14 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:transform_utils", ], @@ -2536,10 +2898,11 @@ cc_library( ":computation_layout", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -2552,6 +2915,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2588,8 +2952,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -2599,6 +2963,7 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], @@ -2623,6 +2988,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", ], alwayslink = 1, ) @@ -2639,6 +3007,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2720,9 +3089,9 @@ cc_library( hdrs = ["stream_pool.h"], deps = [ "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -2820,6 +3189,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -2840,7 +3211,7 @@ cc_library( hdrs = ["tuple_util.h"], deps = [ ":hlo", - "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -2866,7 +3237,8 @@ cc_library( ":hlo_creation_utils", ":tuple_util", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -2880,6 +3252,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2895,6 +3268,8 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2922,6 +3297,8 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2955,13 +3332,13 @@ cc_library( cc_library( name = "source_map_util", - srcs = ["source_map_util.cc"], + srcs = [], hdrs = ["source_map_util.h"], deps = [ ":executable", "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2976,6 +3353,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -3007,8 +3388,11 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -3017,11 +3401,15 @@ tf_cc_test( size = "small", srcs = ["hlo_parser_test.cc"], deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_matchers", ":hlo_parser", "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", # fixdeps: keep + "@com_google_absl//absl/strings", ], ) @@ -3040,6 +3428,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 946ef6f0d6b9025b84c4b9341f4ec600465d4b1e..75dae7a7141647d7b7b60b0e07e11c143621ea63 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -22,13 +22,20 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -40,8 +47,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -122,6 +127,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleImag(HloInstruction* imag) override; + Status HandleIota(HloInstruction* instruction) override; + Status HandleConvolution(HloInstruction* convolution) override; Status HandleDivide(HloInstruction* divide) override; @@ -198,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -266,7 +273,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { StatusOr OptimizeDotOfConcat(HloInstruction* dot); StatusOr OptimizeDotOfConcatHelper( - const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); StatusOr OptimizeDotOfGather(HloInstruction* dot); @@ -289,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return scalar_add_computation_; } + // Tries to fold a kPad in the input or filter into the convolution + // instruction's window. + StatusOr FoldConvInputPad(HloInstruction* convolution); + StatusOr FoldConvFilterPad(HloInstruction* convolution); + + // Tries to use a kDot in place of the given convolution. + StatusOr SimplifyConvToDot(HloInstruction* convolution); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -305,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplification on platforms where it causes a slowdown. + // Disable convolution -> dot simplification on platforms where it causes a + // slowdown. bool enable_conv_simplification_; // Cached computation for adding two scalar F32. @@ -444,8 +460,7 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { Status AlgebraicSimplifierVisitor::HandleConcatenate( HloInstruction* concatenate) { - tensorflow::gtl::ArraySlice operands( - concatenate->operands()); + absl::Span operands(concatenate->operands()); if (operands.size() == 1) { // Unary concatenates are useless. ReplaceInstructionIfSameShape(concatenate, operands[0]); @@ -521,7 +536,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + HloInstruction::CreateConstant(literal.Clone())); } } @@ -540,7 +555,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = MakeUnique( + Literal unique_scalar( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); @@ -548,6 +563,14 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { constant, HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); } + + // If a literal is an increasing sequence from zero, replace it with an iota. + if (ShapeUtil::Rank(constant->shape()) == 1 && + ShapeUtil::ElementsIn(constant->shape()) > 1 && + constant->literal().IsR1Iota()) { + return ReplaceWithNewInstruction( + constant, HloInstruction::CreateIota(constant->shape(), 0)); + } return Status::OK(); } @@ -575,7 +598,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { namespace { template Status InvertConstant(const HloInstruction& constant, Literal* result) { - return result->Populate([&](tensorflow::gtl::ArraySlice indices) { + return result->Populate([&](absl::Span indices) { return T{1.0} / constant.literal().Get(indices); }); } @@ -662,7 +685,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } auto inverse = computation_->AddInstruction( - HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + HloInstruction::CreateConstant((new_literal.Clone()))); TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); @@ -722,12 +745,25 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( } const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; - auto reshape_if_necessary = [&](HloInstruction* hlo) { - if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { + if (hlo->shape().element_type() == element_type) { return hlo; } - return computation_->AddInstruction( - HloInstruction::CreateReshape(dot->shape(), hlo)); + return computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + }; + + auto reshape_if_necessary = [&](HloInstruction* hlo) { + hlo = as_type(hlo, dot->shape().element_type()); + if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + hlo = computation_->AddInstruction( + HloInstruction::CreateReshape(dot->shape(), hlo)); + } + return hlo; + }; + + auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { + return AddReduce(as_type(hlo, F32), dim); }; auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, @@ -747,7 +783,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( if (ShapeUtil::Rank(rhs->shape()) == 1 && ShapeUtil::Rank(lhs->shape()) == 1) { TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(AddReduce( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( multiply(Flatten(lhs), Flatten(rhs)), 0)))); return true; } @@ -781,17 +817,17 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, - reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0)))); + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( + multiply(Flatten(lhs), rhs), 0)))); return true; } TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary( - AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), - rhs_collapsing_dim), - rhs), - rhs_collapsing_dim)))); + dot, reshape_if_necessary(add_reduce_in_f32( + multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), + rhs_collapsing_dim), + rhs), + rhs_collapsing_dim)))); return true; } @@ -803,7 +839,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary(AddReduce( + dot, reshape_if_necessary(add_reduce_in_f32( multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), lhs_collapsing_dim)), lhs_collapsing_dim)))); @@ -827,18 +863,18 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( TF_ASSIGN_OR_RETURN( HloInstruction * optimized_lhs_concat, - OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs, + OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs, rhs_contracting_dim, /*swapped=*/false)); if (optimized_lhs_concat) { return optimized_lhs_concat; } - return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs, + return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs, lhs_contracting_dim, /*swapped=*/true); } StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( - const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && lhs->concatenate_dimension() == lhs_contracting_dim && @@ -936,12 +972,13 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( new_dot_rhs = rhs_slice; } - auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums)); + auto* new_dot = computation_->AddInstruction( + HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs, + new_dot_dnums, dot.precision_config())); if (add_result) { add_result = computation_->AddInstruction(HloInstruction::CreateBinary( - dot_shape, HloOpcode::kAdd, add_result, new_dot)); + dot.shape(), HloOpcode::kAdd, add_result, new_dot)); } else { add_result = new_dot; } @@ -1037,9 +1074,11 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); const int n = right_operand->shape().dimensions(1 - rhs_contracting_dimension); - auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); - auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( - memoized_shape, left_operand, right_operand, dnums)); + auto memoized_shape = + ShapeUtil::MakeShape(dot->shape().element_type(), {m, n}); + auto* memoized_inst = computation_->AddInstruction( + HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, + dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. HloInstruction* original_start_indices = lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); @@ -1084,10 +1123,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or - // below. - if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || - ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { + // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are + // rank 2 or below. + if ((dot->shape().element_type() != F32 && + dot->shape().element_type() != BF16) || + ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || + ShapeUtil::Rank(dot->shape()) > 2) { return Status::OK(); } @@ -1135,8 +1176,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), - rhs->mutable_operand(0), lhs->mutable_operand(0), - dot_dimension_numbers)); + rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers, + dot->precision_config())); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -1232,9 +1273,8 @@ namespace { // return value = {1, 3} // // Precondition: input_dim_indices is sorted. -std::pair> ReshapeLeavesDimensionsUnmodified( - const HloInstruction* hlo, - tensorflow::gtl::ArraySlice input_dim_indices) { +absl::optional> ReshapeLeavesDimensionsUnmodified( + const HloInstruction* hlo, absl::Span input_dim_indices) { CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); @@ -1252,11 +1292,11 @@ std::pair> ReshapeLeavesDimensionsUnmodified( } if (i >= unmodified_dims.size() || unmodified_dims[i].first != input_dim_index) { - return std::make_pair(false, std::vector()); + return absl::nullopt; } output_dim_indices.push_back(unmodified_dims[i].second); } - return std::make_pair(true, output_dim_indices); + return output_dim_indices; } // Returns true if the output of "instruction" is a permutation of the @@ -1385,6 +1425,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } + // broadcast(iota) -> iota. + if (operand->opcode() == HloOpcode::kIota) { + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateIota( + broadcast->shape(), + dims[Cast(operand)->iota_dimension()])); + } + // Merge two consecutive broadcasts into a single one. if (operand->opcode() == HloOpcode::kBroadcast) { std::vector new_dimensions; @@ -1439,6 +1488,19 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { + // iota -> zero if the iota dimension never produces an element other than + // zero. + auto* iota = Cast(instruction); + if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { + auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(iota->shape().element_type()).Clone())); + return ReplaceWithNewInstruction( + iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { return ReplaceWithNewInstruction( @@ -1535,7 +1597,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( - LiteralUtil::One(power->shape().element_type()).CloneToUnique()); + LiteralUtil::One(power->shape().element_type()).Clone()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -1570,7 +1632,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::One(rhs->shape().element_type()).CloneToUnique())); + LiteralUtil::One(rhs->shape().element_type()).Clone())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -1705,16 +1767,33 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = reshape->shape(); + return ReplaceInstruction(reshape, operand); + } if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( reshape, reshape->operand(0)->dimensions()); - if (opt_dims.first) { + if (opt_dims.has_value()) { return ReplaceWithNewInstruction( reshape, HloInstruction::CreateBroadcast( reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), - opt_dims.second)); + *opt_dims)); + } + } + + // reshape(iota) -> iota. + if (operand->opcode() == HloOpcode::kIota) { + auto* iota = Cast(operand); + auto opt_dims = + ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()}); + if (opt_dims.has_value()) { + CHECK_EQ(opt_dims->size(), 1); + return ReplaceWithNewInstruction( + reshape, + HloInstruction::CreateIota(reshape->shape(), opt_dims->front())); } } @@ -1748,8 +1827,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { } auto is_unstrided_slice = [](const HloInstruction* hlo) { - return c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); }; if (slice->operand(0)->opcode() == HloOpcode::kSlice && is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) { @@ -1803,9 +1882,15 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( } Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Most of those optimizations can be done for multi-output + // reduces. + if (ShapeUtil::IsTuple(reduce->shape())) { + return Status::OK(); + } + auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (ShapeUtil::IsZeroElementArray(arg->shape()) || ShapeUtil::IsZeroElementArray(reduce->shape())) { @@ -1920,7 +2005,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // This should make fusion easier or use less memory bandwidth in the unfused // case. if (arg->opcode() == HloOpcode::kConcatenate && - c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) { + absl::c_linear_search(reduce->dimensions(), + arg->concatenate_dimension())) { HloInstruction* old_reduce = nullptr; for (HloInstruction* operand : arg->operands()) { HloInstruction* new_reduce = computation_->AddInstruction( @@ -1973,9 +2059,9 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() - << (convert != nullptr ? tensorflow::strings::StrCat( - "\nvia convert: ", convert->ToString()) - : ""); + << (convert != nullptr + ? absl::StrCat("\nvia convert: ", convert->ToString()) + : ""); // Do not fold interior padding into ReduceWindow since the backends do not // support it. @@ -1996,12 +2082,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (pad_literal == reduce_init_literal) { return true; } - auto converted_pad_literal = pad_literal.ConvertToShape( - reduce_init_value->shape(), /*round_f32_to_bf16=*/true); + auto converted_pad_literal = + pad_literal.ConvertToShape(reduce_init_value->shape()); if (!converted_pad_literal.ok()) { return false; } - return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + return converted_pad_literal.ValueOrDie() == reduce_init_literal; }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. @@ -2138,6 +2224,11 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = transpose->shape(); + return ReplaceInstruction(transpose, operand); + } + if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); @@ -2146,40 +2237,157 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleConvolution( +StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( HloInstruction* convolution) { - auto lhs = convolution->mutable_operand(0); - auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::IsZeroElementArray(lhs->shape()) || - ShapeUtil::IsZeroElementArray(rhs->shape())) { - return ReplaceWithNewInstruction( - convolution, - HloInstruction::CreateBroadcast( - convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()) - .CloneToUnique())), - {})); + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + + if (lhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(lhs->operand(1), 0)) { + return false; + } + + const auto& padding = lhs->padding_config(); + + // Can't pad batch or feature dims. + for (int64 dim : + {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { + return false; + } + } + + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = window; + for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); + // Edge padding composes with itself in the straightforward way, but + // composing interior padding is nontrivial, and we cowardly refuse to + // think about it. If we see interior padding in either the kPad or conv, + // bail if there's any sort of padding in the other. + if (p.interior_padding() != 0 && + (w.padding_low() != 0 || w.padding_high() != 0 || + w.base_dilation() != 1)) { + return false; + } + if (w.base_dilation() != 1 && + (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0)) { + return false; + } + + w.set_padding_low(w.padding_low() + p.edge_padding_low()); + w.set_padding_high(w.padding_high() + p.edge_padding_high()); + if (p.interior_padding() != 0) { + CHECK_EQ(w.base_dilation(), 1); + w.set_base_dilation(1 + p.interior_padding()); + } + } + + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs->mutable_operand(0), rhs}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} + +StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + + if (rhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(rhs->operand(1), 0)) { + return false; + } + + const auto& padding = rhs->padding_config(); + + // Can't pad or dilate feature dims. + for (int64 dim : {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { + return false; + } + } + + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = convolution->window(); + for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); + + // We can only do this transformation if p adds dilation to the filter -- + // edge padding on the filter is not supported in conv. + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { + return false; + } + + // Nothing to do if the kPad for this dim is entirely a nop. + if (p.interior_padding() == 0) { + continue; + } + + // We cowardly refuse to think about how dilation composes with itself; + // bail if both the kPad and conv have dilation on this dimension. + if (w.window_dilation() > 1) { + return false; + } + CHECK_EQ(w.window_dilation(), 1); + w.set_window_dilation(1 + p.interior_padding()); + w.set_size(rhs->operand(0)->shape().dimensions( + dnums.kernel_spatial_dimensions(dim))); } + + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs, rhs->mutable_operand(0)}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} + +StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + if (!enable_conv_simplification_) { - return Status::OK(); + return false; } - // HandleConvolution tries to replace a convolution with a DOT instruction. - // - // Only add when bitcasts can be used: - // - if bitcasts are not supported, then reshapes could be used but will - // end up with another copy. - // - if bitcasts are supported, the simplifier will be called again with - // bitcasts_ == true. - // TODO(cwhipkey): b/31337498, make this layout insensitive. + // TODO(b/31337498): For now, we cowardly refuse to do this optimization in + // layout-insensitive mode, for fear of adding nontrivial reshapes. if (!is_layout_sensitive_) { - return Status::OK(); + return false; } - const ConvolutionDimensionNumbers& dnums = - convolution->convolution_dimension_numbers(); const Shape& input_shape = lhs->shape(); const Shape& filter_shape = rhs->shape(); const Shape& convolution_shape = convolution->shape(); @@ -2190,7 +2398,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Require the spatial dimensions in the kernel to have a bound of one. for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { - return Status::OK(); + return false; } } @@ -2201,7 +2409,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // for a 1x1 window, so window dilation is no problem. if (window_util::HasStride(window) || window_util::HasPadding(window) || window_util::HasBaseDilation(window)) { - return Status::OK(); + return false; } // Also, the shapes must align for a rowmajor matmul: @@ -2227,7 +2435,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dnums.kernel_input_feature_dimension()) < PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { - return Status::OK(); + return false; } auto add_bitcast = [&](Shape shape, HloInstruction* operand) { @@ -2269,7 +2477,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( if (!valid_bitcast_callback_(input_shape, new_input_shape) || !valid_bitcast_callback_(filter_shape, new_filter_shape) || !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { - return Status::OK(); + return false; } auto new_lhs = add_bitcast(new_input_shape, lhs); @@ -2278,8 +2486,47 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); - return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, + convolution->precision_config())); + + TF_RETURN_IF_ERROR( + ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); + return true; +} + +Status AlgebraicSimplifierVisitor::HandleConvolution( + HloInstruction* convolution) { + // Zero-sized input or filter. + if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(convolution->shape().element_type()))), + {})); + } + + // Try to merge padding/dilation of the input with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); + if (folded_input_pad) { + return Status::OK(); + } + + // Try to merge dilation of the filter with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); + if (folded_filter_pad) { + return Status::OK(); + } + + // Try to replace the convolution with a kDot instruction. + TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); + if (replaced_with_dot) { + return Status::OK(); + } + + return Status::OK(); } bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index c48196e861a559a5abfa360841ec70b39356fa2b..9f8d0ee88bdebcf17310cd0407b1b99e4b0a7b5f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -24,7 +24,7 @@ limitations under the License. namespace xla { // A pass which performs algebraic simplifications. -class AlgebraicSimplifier : public HloPassInterface { +class AlgebraicSimplifier : public HloModulePass { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to // bitcast from 'from_shape' to 'to_shape' after considering platform @@ -47,7 +47,7 @@ class AlgebraicSimplifier : public HloPassInterface { enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} ~AlgebraicSimplifier() override = default; - tensorflow::StringPiece name() const override { return "algsimp"; } + absl::string_view name() const override { return "algsimp"; } // Run algebraic simplification on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 862cbeeba6b82e1f24a6616b3237dc47d022e9af..2047f894b465816eb97ba205e79033bd52bf7a0c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -18,11 +18,15 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" @@ -34,13 +38,12 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -using ::testing::ElementsAre; namespace xla { namespace { +using ::testing::ElementsAre; + namespace op = xla::testing::opcode_matchers; AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { @@ -290,6 +293,21 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { EXPECT_THAT(root, op::Constant()); } +TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); +} + // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -513,7 +531,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({0.f, 1.f, 2.f}))); + LiteralUtil::CreateR1({1.f, 2.f, 3.f}))); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); @@ -1026,7 +1044,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { dim->set_window_reversal(false); // Create add computation. builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); + ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -1428,6 +1447,37 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } +// Test transforming reshapes and transposes of rng. +TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { + HloComputation::Builder builder(TestName()); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction* rng0 = builder.AddInstruction( + HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}), + RandomDistribution::RNG_UNIFORM, {zero, one})); + + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0})); + Shape reshape_shape = builder + .AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {4}), transpose)) + ->shape(); + + auto computation = module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + // Verify that that reshape(transpose(rng)) is replace by a single rng of the + // same shape as the reshape. + EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), + reshape_shape)); +} + // Test transforming reshapes to bitcasts under various conditions. TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { HloComputation::Builder builder(TestName()); @@ -1789,6 +1839,126 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { op::Reshape(op::Broadcast(param))); } +TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction(HloInstruction::CreateIota( + ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2)); + Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}); + builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); +} + +TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0)); + auto result_shape = iota->shape(); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + auto root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement()); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_EQ(Cast(computation->root_instruction()) + ->iota_dimension(), + 3); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + const int64 iota_dim = + Cast(computation->root_instruction()) + ->iota_dimension(); + EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -1975,6 +2145,269 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); } +// Used for TEST_Ps that test merging (or not) of a kPad instruction into a +// convolution's Window. +struct ConvPaddingTestcase { + ConvPaddingTestcase(absl::string_view padding, + absl::string_view orig_conv_window, + absl::string_view expected_conv_window) + : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window, + /*pad_value=*/0) {} + + ConvPaddingTestcase(absl::string_view padding, + absl::string_view orig_conv_window, + absl::string_view expected_conv_window, float pad_value) + : padding(padding), + orig_conv_window(orig_conv_window), + expected_conv_window(expected_conv_window), + pad_value(pad_value) {} + + string ToString() const { + return absl::StrFormat( + "padding=%s, orig_conv_window=%s, expected_conv_window=%s, " + "pad_value=%f", + padding, orig_conv_window, expected_conv_window, pad_value); + } + + string padding; + string orig_conv_window; + string expected_conv_window; + float pad_value; +}; + +// ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a +// computation that does +// +// conv(pad(param0, padding=padding), param1), window=orig_conv_window +// +// gets transformed by AlgebraicSimplifier to +// +// conv(param0, param1), window=expected_conv_window +// +// or, if expected_conv_window is the empty string, checks that +// AlgebraicSimplifier does *not* transform the original convolution. +class ConvInputPaddingTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_CASE_P( + ConvInputPaddingTestCases, ConvInputPaddingTest, + ::testing::ValuesIn(std::vector{ + // Merge this edge padding into the conv. + {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"}, + // Merge this edge padding with the conv's edge padding. + {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"}, + // Merge this interior-padded kPad with the unpadded conv. The 3x6 + // interior padding gets transformed to 4x7 conv lhs dilation. + {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"}, + // kPad has dilation on one dim, conv has it on the other; merge them. + {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"}, + // kPad has dilation and edge padding on one dim, conv has them on the + // other; merge them. + {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10", + "pad=0_1x3_0 lhs_dilate=2x10"}, + + // Don't transform if the pad value is nonzero. + {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1}, + + // We refuse to transform the following because on some dimension, one + // of the kPad and conv has dilation and the other has some sort of + // padding. + {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""}, + {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""}, + {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""}, + + // We can't merge feature or batch padding into the conv. + {"1_0x0_0x0_0x0_0", "", ""}, + {"0_0x1_0x0_0x0_0", "", ""}, + })); + +TEST_P(ConvInputPaddingTest, DoTest) { + ConvPaddingTestcase testcase = GetParam(); + + // It would be better to put the testcase's ToString into the test name, but + // gUnit has constraints on what can go into test names, and any reasonable + // implementation of ToString() seems to violate them. + SCOPED_TRACE(testcase.ToString()); + + auto builder = HloComputation::Builder(TestName()); + auto* input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}), // bf01 + "input")); + auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(testcase.pad_value))); + + PaddingConfig padding_config = + ParsePaddingConfig(testcase.padding).ValueOrDie(); + auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape(input->shape(), pad_value->shape(), + padding_config) + .ValueOrDie(), + input, pad_value, padding_config)); + + auto* filter = builder.AddInstruction(HloInstruction::CreateParameter( + 1, + ShapeUtil::MakeShape( + F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}), // io01 + "input")); + + ConvolutionDimensionNumbers dnums = + ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie(); + Window window = + ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window)) + .ValueOrDie(); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), + /*feature_group_count=*/1, window, + dnums) + .ValueOrDie(), + lhs_pad, filter, /*feature_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + if (testcase.expected_conv_window.empty()) { + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + } else { + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + auto* conv = module->entry_computation()->root_instruction(); + SCOPED_TRACE(module->ToString()); + ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + EXPECT_EQ(window_util::ToString(conv->window()), + absl::StrCat("size=3x3 ", testcase.expected_conv_window)); + } +} + +// ConvFilterPaddingTest (and its one associated TEST_P) checks that a +// computation that does +// +// conv(param0, pad(param1, padding=padding)), window=orig_conv_window +// +// gets transformed by AlgebraicSimplifier to +// +// conv(param0, param1), window=expected_conv_window +// +// or, if expected_conv_window is the empty string, checks that +// AlgebraicSimplifier does *not* transform the original convolution. +class ConvFilterPaddingTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_CASE_P( + ConvFilterPaddingTestCases, ConvFilterPaddingTest, + ::testing::ValuesIn(std::vector{ + // Can only merge interior padding on the filter's spatial dimensions; + // all + // other paddings (edge padding and interior padding on the channel + // dims) + // should be rejected out of hand. + {"1_0_0x0_0_0x0_0x0_0", "", ""}, + {"0_1_0x0_0_0x0_0x0_0", "", ""}, + {"0_0_1x0_0_0x0_0x0_0", "", ""}, + {"0_0_0x1_0_0x0_0x0_0", "", ""}, + {"0_0_0x0_1_0x0_0x0_0", "", ""}, + {"0_0_0x0_0_1x0_0x0_0", "", ""}, + {"0_0_0x0_0_0x1_0x0_0", "", ""}, + {"0_0_0x0_0_0x0_1x0_0", "", ""}, + {"0_0_0x0_0_0x0_0x1_0", "", ""}, + {"0_0_0x0_0_0x0_0x0_1", "", ""}, + + // Interior padding on channel dims can be merged into the conv, so long + // as the conv and pad don't have interior padding on the same dim. + {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"}, + {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"}, + {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"}, + {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"}, + {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"}, + + // Can't merge if for a given dim there's interior padding on both the + // pad and conv. + {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""}, + {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""}, + + // Don't transform if the pad value is nonzero. + {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1}, + })); + +TEST_P(ConvFilterPaddingTest, DoIt) { + ConvPaddingTestcase testcase = GetParam(); + + // It would be better to put the testcase's ToString into the test name, but + // gUnit has constraints on what can go into test names, and any reasonable + // implementation of ToString() seems to violate them. + SCOPED_TRACE(testcase.ToString()); + + auto builder = HloComputation::Builder(TestName()); + auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(testcase.pad_value))); + auto* filter = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}), // io01 + "input")); + PaddingConfig padding_config = + ParsePaddingConfig(testcase.padding).ValueOrDie(); + auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape(filter->shape(), pad_value->shape(), + padding_config) + .ValueOrDie(), + filter, pad_value, padding_config)); + + auto* input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeShape( + F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}), // bf01 + "input")); + + ConvolutionDimensionNumbers dnums = + ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie(); + Window window = ParseWindow(absl::StrFormat("size=%dx%d %s", + rhs_pad->shape().dimensions(2), + rhs_pad->shape().dimensions(3), + testcase.orig_conv_window)) + .ValueOrDie(); + + // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place + // after the transformation. + PrecisionConfig precision_config; + precision_config.add_operand_precision(PrecisionConfig::HIGH); + precision_config.add_operand_precision(PrecisionConfig::HIGHEST); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), + /*feature_group_count=*/1, window, + dnums) + .ValueOrDie(), + input, rhs_pad, /*feature_group_count=*/1, window, dnums, + precision_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + if (testcase.expected_conv_window.empty()) { + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + } else { + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + auto* conv = module->entry_computation()->root_instruction(); + SCOPED_TRACE(module->ToString()); + ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + EXPECT_EQ(window_util::ToString(conv->window()), + absl::StrFormat("size=%dx%d %s", + conv->operand(1)->shape().dimensions(2), + conv->operand(1)->shape().dimensions(3), + testcase.expected_conv_window)); + EXPECT_THAT(Cast(conv) + ->precision_config() + .operand_precision(), + ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST)); + } +} + TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { struct ConvTestOptions { int in_batch = 10; @@ -2006,7 +2439,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { // Builds a convolution from and runs algebraic simplification on // the computation. Returns a string description of the result of // simplification. - auto build_and_simplify = [&options]() -> string { + auto build_and_simplify = [&]() -> string { HloComputation::Builder b(TestName()); Window window; @@ -2078,7 +2511,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { auto out_dims = in_dims; out_dims[in_channel_idx] = options.f_output_channels; - auto make_shape = [](tensorflow::gtl::ArraySlice dims, + auto make_shape = [](absl::Span dims, bool minor_to_major_layout) { if (minor_to_major_layout) { return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3}); @@ -2095,8 +2528,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { HloInstruction* filter = b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter")); - b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, - window, dnums)); + b.AddInstruction(HloInstruction::CreateConvolve( + out_shape, input, filter, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. auto module = HloTestBase::CreateNewModule(); @@ -2112,9 +2546,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { root->operand(0)->opcode() == HloOpcode::kDot) { auto lhs_shape = root->operand(0)->operand(0)->shape(); auto rhs_shape = root->operand(0)->operand(1)->shape(); - return tensorflow::strings::StrCat( - tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ", - tensorflow::str_util::Join(rhs_shape.dimensions(), "x")); + return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ", + absl::StrJoin(rhs_shape.dimensions(), "x")); } return "UNEXPECTED CHANGE"; }; @@ -2475,7 +2908,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, + DefaultPrecisionConfig(2))); std::unique_ptr dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -2498,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; - std::unique_ptr value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get()}); + Literal elements[] = {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector)}; + Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto computation = module().AddEntryComputation(builder.Build()); @@ -2617,6 +3051,47 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } +// Test that a broadcast of an iota can be merged to one iota. +TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1)); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast(root)->iota_dimension(), 2); +} + +// Test that a broadcast of an iota can be merged to one iota. +TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { + HloComputation::Builder builder(TestName()); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1)); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast(root)->iota_dimension(), 2); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -2629,11 +3104,10 @@ struct PadReduceWindowEffectiveBroadcastCase { bool should_become_broadcast; string ToTestCaseName() const { - return tensorflow::strings::StrCat( - tensorflow::str_util::Join(input_spatials, ","), ";", - tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";", - tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a, - ";", should_become_broadcast); + return absl::StrCat(absl::StrJoin(input_spatials, ","), ";", + absl::StrJoin(symmetric_pad_spatials, ","), ";", + absl::StrJoin(reduce_window_spatials, ","), ";", + prepend_a, ";", should_become_broadcast); } }; @@ -2651,8 +3125,8 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { // a and b are parallel bounds we can either turn into a B F S0 S1 or // `B S0 S1 F` kind of pattern. - auto decorate_spatials = [¶m](tensorflow::gtl::ArraySlice spatials, - int64 a, int64 b) { + auto decorate_spatials = [¶m](absl::Span spatials, int64 a, + int64 b) { std::vector result; if (param.prepend_a) { result.push_back(a); @@ -2759,17 +3233,18 @@ INSTANTIATE_TEST_CASE_P( class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< - ::testing::tuple> {}; + ::testing::tuple> {}; TEST_P(DotStrengthReductionTest, DotStrengthReduction) { int m, k, n; bool transpose_lhs, transpose_rhs; - std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + PrimitiveType element_type; + std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam(); - Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); - Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); - Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); - Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k}); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( @@ -2787,8 +3262,8 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -2811,7 +3286,7 @@ INSTANTIATE_TEST_CASE_P( DotStrengthReductionTestInstantiation, DotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Bool(), - ::testing::Bool())); + ::testing::Bool(), ::testing::Values(F32, BF16))); struct DotOfConcatTestSpec { int64 m; @@ -2863,8 +3338,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -2927,8 +3402,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3045,8 +3520,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 dot_row_size = 1; int64 dot_col_size = spec.n; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3115,8 +3590,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 dot_row_size = spec.m; int64 dot_col_size = 1; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 51ebc4763b612884a4453edec5711f78c4006fc3..1ed6142dcecdc830cb7b8386e0cc20a2ea54aa7f 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -17,15 +17,15 @@ limitations under the License. #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -69,8 +69,7 @@ StatusOr AllocationTracker::RegisterInternal( return InvalidArgument( "AllocationTracker for platform %s cannot register buffer from " "platform %s", - backend_->platform()->Name().c_str(), - shaped_buffer.platform()->Name().c_str()); + backend_->platform()->Name(), shaped_buffer.platform()->Name()); } } @@ -91,8 +90,9 @@ StatusOr AllocationTracker::RegisterInternal( // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer // into a regular ShapedBuffer, which is stored in // handle_to_shaped_buffers_. - handle_to_shaped_buffers_[handle].emplace_back(MakeUnique( - ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); + handle_to_shaped_buffers_[handle].emplace_back( + absl::make_unique( + ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); } GlobalDataHandle result; @@ -124,7 +124,7 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { // "handle does not exist". auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } for (auto& shaped_buffer : it->second) { @@ -143,7 +143,7 @@ StatusOr> AllocationTracker::DeconstructTuple( // the same for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { - return InvalidArgument("global data handle %lld is not a tuple", + return InvalidArgument("global data handle %d is not a tuple", data.handle()); } // If the on-host representation is a tuple, then the on-device one should be @@ -200,14 +200,14 @@ StatusOr> AllocationTracker::ResolveInternal( VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } std::vector replicated_buffers; for (const auto& shaped_buffer : it->second) { if (shaped_buffer == nullptr) { - return InvalidArgument( - "global data handle %lld was previously deallocated", data.handle()); + return InvalidArgument("global data handle %d was previously deallocated", + data.handle()); } replicated_buffers.push_back(shaped_buffer.get()); } diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index d12be3e007fe0b16ac850d64521f0025d481b5d2..5c180cbdd492031e133b81149f0f4698619b7788 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -111,11 +112,11 @@ StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { return stream_pools_.at(executor).BorrowStream(executor); } -Backend::Backend( - se::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, ComputationPlacer* computation_placer, - int intra_op_parallelism_threads) +Backend::Backend(se::Platform* platform, Compiler* compiler, + absl::Span stream_executors, + TransferManager* transfer_manager, + ComputationPlacer* computation_placer, + int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), @@ -127,8 +128,8 @@ Backend::Backend( } } // Create a memory allocator for the valid stream executors. - memory_allocator_ = - MakeUnique(platform, stream_executors); + memory_allocator_ = absl::make_unique( + platform, stream_executors); CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; @@ -176,7 +177,7 @@ StatusOr Backend::stream_executor( } } return InvalidArgument("device %s not supported by XLA service", - device_name(device_ordinal).c_str()); + device_name(device_ordinal)); } StatusOr Backend::devices_equivalent(int device_ordinal_a, diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 1bc3796fa48c1627538474d04ef5358ba64dfce9..a2dafbe803f8bd5f23e4e9f3f6d3e6f744c9fab9 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -28,8 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -130,7 +130,7 @@ class Backend { // Return a string identifier for the given device, eg: "GPU:3". string device_name(int device_ordinal) const { - return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal); + return absl::StrCat(platform_->Name(), ":", device_ordinal); } // Returns true if the devices with the given ordinals are equivalent from @@ -149,7 +149,7 @@ class Backend { private: struct EigenThreadPoolWrapper; Backend(se::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice stream_executors, + absl::Span stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, int intra_op_parallelism_threads); diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index 2099916509acdbc2680cc2b5bd405e96f2f7bfb8..eda026ac5685dc469a6230094eb28b3618e36400 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -62,7 +63,8 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, - MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers, + batch_dot->precision_config())); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, MakeReshapeHlo(batch_dot->shape(), new_dot)); @@ -76,7 +78,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( return true; } -tensorflow::StringPiece BatchDotSimplification::name() const { +absl::string_view BatchDotSimplification::name() const { return "batch-dot-simplification"; } @@ -84,10 +86,10 @@ StatusOr BatchDotSimplification::Run(HloModule* module) { bool changed = false; std::vector dot_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), - [](HloInstruction* instr) { - return instr->opcode() == HloOpcode::kDot; - }); + absl::c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); } for (HloInstruction* dot_instr : dot_instrs) { TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h index c0ca8d8ebac1a3b218e7bd4d6db02b69cfb6916f..5b625bf3b98b060531532f07de343f7ca4f09ac9 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.h +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -25,10 +25,10 @@ namespace xla { // Normally these would live in the algebraic simplifier, but we want to run // this to fixpoint (this pass reaches fixed point in one execution) before we // run the DotDecomposer. -class BatchDotSimplification : public HloPassInterface { +class BatchDotSimplification : public HloModulePass { public: StatusOr Run(HloModule* module) override; - tensorflow::StringPiece name() const override; + absl::string_view name() const override; private: StatusOr ElideDegenerateBatchDimensionFromBatchDot( diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index c4cd60c1201f7ddbf0aba4b6d587952531b74bfa..30d33e0d3531bb5e931ebfa0b60c91523dd0cb44 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -33,9 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -43,7 +43,7 @@ namespace xla { namespace { -using tensorflow::gtl::optional; +using absl::optional; // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. @@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape feature_shape = scale->shape(); auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( operand_shape, add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); @@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( const Shape feature_shape = scale->shape(); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( operand_shape, computation_->AddInstruction( @@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); auto epsilon_activation = add( @@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( auto elements_per_feature_literal = LiteralUtil::CreateR0(elements_per_feature_int64); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); + elements_per_feature_literal.Convert(ptype)); auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 7ae202c583516443a6263403fb5460d1adbabd97..147f3ae7b6d4ed0d4dadfb136e1e0f0bf3ae90c6 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -26,7 +26,7 @@ namespace xla { // A pass which rewrites batch norm operations into more operations. Breaking a // big operation into smaller operations helps leverage our generic fusion // logic. -class BatchNormExpander : public HloPassInterface { +class BatchNormExpander : public HloModulePass { public: // When use_fusion is set, a multi-output fusion node is created. BatchNormExpander(bool rewrite_training_op = false, @@ -36,7 +36,7 @@ class BatchNormExpander : public HloPassInterface { rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; - tensorflow::StringPiece name() const override { return "batchnorm_expander"; } + absl::string_view name() const override { return "batchnorm_expander"; } // Run operation expander on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index a725351462809e5b670bbf1d79d2dded87e54f07..f7ac8f5482908af104554a1cf812370b9098cda7 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -29,15 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { -using BatchNormExpanderTest = HloTestBase; +using BatchNormExpanderTest = HloVerifiedTestBase; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -67,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -109,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -127,13 +126,13 @@ ENTRY entry { epsilon=0.001, feature_index=1, sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str)); + ParseAndVerifyModule(module_str); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie()); - for (auto* instruction : module->entry_computation()->instructions()) { + for (auto* instruction : module().entry_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kParameter) { continue; } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 1b8b2d204503576c3fcb02f6d5b37f2db45e1768..d63287539dfde5bb4890ab8303ef2205133d8125 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h index c9398387098fad84ba28735c30e426fedd9b0cb0..cb3d12f0bfd0e502136ce39660e091dc1c3879be 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -31,13 +31,13 @@ namespace xla { // optimization pipeline followed by a DCE pass. If other passes are needed // after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the // changed made by this pass. -class BFloat16ConversionFolding : public HloPassInterface { +class BFloat16ConversionFolding : public HloModulePass { public: explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} ~BFloat16ConversionFolding() override = default; - tensorflow::StringPiece name() const override { return "bfloat16-fold"; } + absl::string_view name() const override { return "bfloat16-fold"; } // Run BF16 conversion folding on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 7cf05ca443c00c3b40eeb7d756cf216b45c45c39..5f93740887aa7e61458990992fe0573883ff056d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloTestBase { +class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { protected: + BFloat16ConversionFoldingTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16ConversionFolding fold(&bfloat16_support_); @@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -235,8 +239,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_group_ids=*/{}, /*barrier=*/"", - /*all_reduce_id=*/tensorflow::gtl::nullopt)); + sum, /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( @@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 16e99b57220cc185fbfaa75d30a0de709cf61ee7..d5b1148058898596bfdb837826a590bbc74e202a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -15,12 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum and sort which can have a tuple - // output. - Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleSort(HloInstruction* sort) override; - static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16NormalizationVisitor visitor(computation, bfloat16_support); @@ -73,8 +69,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { // Inserts conversion HLOs to replace the called computations' BF16 // operands/outputs to F32. Status ConvertCalledComputations( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice bf16_called_comps); + HloInstruction* hlo, absl::Span bf16_called_comps); HloComputation* computation_; const BFloat16Support* bfloat16_support_; @@ -118,8 +113,7 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( } Status BFloat16NormalizationVisitor::ConvertCalledComputations( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice bf16_called_comps) { + HloInstruction* hlo, absl::Span bf16_called_comps) { std::map cloned_computations; for (auto& comp : bf16_called_comps) { auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone()); @@ -150,23 +144,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations( return Status::OK(); } -Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( - HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape())) { - return HandleInstruction(crs); - } else { - return HandleMultipleOutputs(crs); - } -} - -Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { - if (!ShapeUtil::IsTuple(sort->shape())) { - return HandleInstruction(sort); - } else { - return HandleMultipleOutputs(sort); - } -} - Status BFloat16NormalizationVisitor::HandleMultipleOutputs( HloInstruction* hlo) { std::vector operand_types(hlo->operand_count()); @@ -380,6 +357,12 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kConditional) { return Status::OK(); } + // TODO(b/112040122): Correctly normalize variadic reduce. + if ((hlo->opcode() == HloOpcode::kSort || + hlo->opcode() == HloOpcode::kCrossReplicaSum) && + ShapeUtil::IsTuple(hlo->shape())) { + return HandleMultipleOutputs(hlo); + } return HandleInstruction(hlo); } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h index 2a60fe0af3218484acb95e6c69815d551350764c..f48e925823cf02bf4351b9bc7741123f5b1cd06f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.h +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -25,13 +25,13 @@ namespace xla { // A pass which adds F32 <-> BF16 conversions for HLO instructions that do not // support BF16 input/output or mixed precision, according to the passed-in // backend-specific BF16 support rules. -class BFloat16Normalization : public HloPassInterface { +class BFloat16Normalization : public HloModulePass { public: explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} ~BFloat16Normalization() override = default; - tensorflow::StringPiece name() const override { return "bf16-normalization"; } + absl::string_view name() const override { return "bf16-normalization"; } // Run BF16 normalization on the given computation. Returns whether the // computation was changed. @@ -48,13 +48,13 @@ class BFloat16Normalization : public HloPassInterface { // use mixed precision; it removes mixed precision even if the backend supports // it. This pass is used to make the HLO module valid for other HLO passes which // do not support mixed precision. -class BFloat16MixedPrecisionRemoval : public HloPassInterface { +class BFloat16MixedPrecisionRemoval : public HloModulePass { public: BFloat16MixedPrecisionRemoval() {} ~BFloat16MixedPrecisionRemoval() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "bf16-mixed-precision-removal"; } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index f9f1f64998f5b925102dc238941897ff6d441b3f..cef0eba14e9dd463d6c32b047211bf25a84478f6 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,15 +68,20 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloTestBase { +class BFloat16NormalizationTest : public HloVerifiedTestBase { protected: + BFloat16NormalizationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16Normalization normalization(&bfloat16_support_); StatusOr result = normalization.Run(module); EXPECT_IS_OK(result.status()); - HloVerifier verifier(/*allow_mixed_precision=*/true); + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); EXPECT_IS_OK(verifier.Run(module).status()); return result.ValueOrDie(); @@ -104,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module.get())); + EXPECT_FALSE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -132,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -162,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -200,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -251,14 +256,14 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_group_ids=*/{}, /*barrier=*/"", - /*all_reduce_id=*/tensorflow::gtl::nullopt)); + /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -285,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -307,13 +312,16 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 2fb401c4289728f3f59538464c5b8ad49957985b..58f78f8e24d0bc00a63e3583828cf8e01ae4531a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -407,7 +407,7 @@ void BFloat16Propagation::AdjustCalledComputationParameters( HloInstruction* hlo) { auto adjust_computation = [this, hlo](HloComputation* computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { // Adjust parameters. CHECK_EQ(operands.size(), computation->num_parameters()); for (int64 i = 0; i < operands.size(); ++i) { @@ -675,10 +675,8 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { continue; } if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { - TF_ASSIGN_OR_RETURN( - auto converted_literal, - hlo->literal().ConvertToShape(hlo->shape(), - /*round_f32_to_bf16=*/true)); + TF_ASSIGN_OR_RETURN(auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape())); auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 02b8cad089dd8465b7af5c1014e37b77ded6949d..6a62439f8877634a065979d1e2fcda262ca83dc1 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -58,15 +58,13 @@ namespace xla { // BFloat16ConversionFolding. If other passes are needed after this pass, run // BFloat16MixedPrecisionRemoval first to undo some of the changes made by this // pass. -class BFloat16Propagation : public HloPassInterface { +class BFloat16Propagation : public HloModulePass { public: explicit BFloat16Propagation(const BFloat16Support* bfloat16_support); ~BFloat16Propagation() override = default; - tensorflow::StringPiece name() const override { - return "bfloat16-propagation"; - } + absl::string_view name() const override { return "bfloat16-propagation"; } // Runs the pass on the given module. Returns whether the module was changed // (precision reductions were added). diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 69b654d30e42b1ed69304206f09120e86831d468..e032b5c624c0151fd63c870e0f21ec97656d625f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,8 +55,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloTestBase { +class BFloat16PropagationTest : public HloVerifiedTestBase { protected: + BFloat16PropagationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. bool PropagatePrecision(HloModule* module) { @@ -77,6 +81,16 @@ class BFloat16PropagationTest : public HloTestBase { inst->users()[0]->opcode() == HloOpcode::kConvert && inst->users()[0]->shape().element_type() == BF16; } + + std::unique_ptr CreateDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + DefaultPrecisionConfig(2)); + } }; // Tests that BF16 can propagate through select over non-tuple buffers, but not @@ -95,22 +109,22 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); - HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b)); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b)); HloInstruction* sel = builder.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a)); - HloInstruction* root = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a)); + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(xpose)); @@ -136,13 +150,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a))); HloInstruction* b = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(dot->operand(0))); @@ -150,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)), dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)), dot->operand(1)->literal())); } @@ -189,8 +202,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple0->shape(), tuple1, 0)), 0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); HloInstruction* output_tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); @@ -198,7 +211,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), output_tuple); EXPECT_TRUE(OutputsBF16(xpose)); @@ -231,13 +244,13 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1)); // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1. - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(add1)); @@ -249,7 +262,7 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { // Tests that a non-fusion computation's root should not be changed. TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* a = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); @@ -258,8 +271,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add)); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); @@ -267,7 +279,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_FALSE(OutputsBF16(add)); @@ -277,7 +289,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -303,15 +315,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { HloInstruction::CreateGetTupleElement(shape, p_f1, 0)); HloInstruction* b_f1 = builder_f1.AddInstruction( HloInstruction::CreateGetTupleElement(shape, p_f1, 1)); - HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1)); + HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1)); auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build()); auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion( dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion1); EXPECT_TRUE(OutputsBF16(add)); @@ -326,7 +337,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -340,15 +351,15 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); HloInstruction* add_f = builder_f.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); - HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f)); + HloInstruction* dot_f = + builder_f.AddInstruction(CreateDot(shape, add_f, add_f)); auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion); } @@ -390,12 +401,11 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { HloInstruction::CreateGetTupleElement(shape, fusion, 0)); HloInstruction* gte1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, fusion, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(gte0)); @@ -440,12 +450,12 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add0)); @@ -472,31 +482,36 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { auto builder_cond = HloComputation::Builder("cond"); auto cond_param = builder_cond.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_param, cond_param)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param)); auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); auto body_param = builder_body.AddInstruction( HloInstruction::CreateParameter(0, shape, "body_param")); - auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, body_param, body_param)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_param, body_param)); auto body = module->AddEmbeddedComputation(builder_body.Build()); auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE( @@ -528,10 +543,16 @@ TEST_F(BFloat16PropagationTest, HloInstruction::CreateParameter(0, shape, "cond_param")); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1}, + {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -552,11 +573,10 @@ TEST_F(BFloat16PropagationTest, auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add)); EXPECT_FALSE(OutputsBF16(body_fusion)); @@ -593,14 +613,20 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { // This add should prevent RHS from using BF16 auto cond_add_rhs = builder_cond.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_lhs, cond_add_rhs)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs)); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -610,10 +636,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot1 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); - auto body_dot2 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs)); + auto body_dot1 = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); + auto body_dot2 = + builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs)); auto body_transpose = builder_body.AddInstruction( HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); builder_body.AddInstruction( @@ -627,11 +653,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, while_hlo, 0)); auto rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, while_hlo, 1)); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); @@ -683,14 +708,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond0_add_rhs = builder_cond0.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); - auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs)); + auto cond0_dot = + builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs)); builder_cond0.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond0_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); // Condition computation for the second while. @@ -705,14 +736,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond1_add_lhs = builder_cond1.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); - auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs)); + auto cond1_dot = + builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs)); builder_cond1.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond1_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); // Body computation shared by both whiles. @@ -723,8 +760,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); builder_body.AddInstruction( HloInstruction::CreateTuple({body_dot, body_rhs})); auto body = module->AddEmbeddedComputation(builder_body.Build()); @@ -734,23 +771,22 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1)); - auto lhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 1)))); - auto rhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 1)))); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto lhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 1)))); + auto rhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 1)))); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_FALSE(OutputsBF16(body_dot)); EXPECT_FALSE(OutputsBF16(body_rhs)); EXPECT_FALSE(OutputsBF16(body_lhs)); @@ -792,7 +828,7 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), add2); EXPECT_EQ(add2->operand(0), add0); @@ -821,15 +857,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { HloInstruction::CreateGetTupleElement(shape, domain, 0)); HloInstruction* b_gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, domain, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); // test BF16 propagated through domain @@ -867,15 +902,15 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); HloInstruction* b_trans = builder.AddInstruction( HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans)); + HloInstruction* dot = + builder.AddInstruction(CreateDot(shape, a_trans, b_trans)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(a_trans)); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 118a11c8de3c06d240079723f0a5db314cfcace5..34a7be0e9c079e9e42c28eef10af4079e99853b6 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,13 +22,14 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -36,20 +37,15 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { +namespace { +using absl::StrAppend; +using absl::StrAppendFormat; using ::tensorflow::gtl::FlatMap; using ::tensorflow::gtl::FlatSet; -using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; - -namespace { template string ColocatedBufferSetsToString(const T& container, const char* title) { @@ -61,12 +57,65 @@ string ColocatedBufferSetsToString(const T& container, const char* title) { return result; } -// Walk the call graph of the HLO module and place each computation into either -// thread_local_computations or global_computations depending upon whether the -// computation requires thread-local allocations or global allocations. The -// elements in thread_local_computations and global_computations are in post -// order (if computation A has an instruction which calls computation B, then A -// will appear after B in the vector). +// Checks that points-to set of 'instruction' is unambiguous and distinct +// (ensured by CopyInsertion), then adds the buffer from the points-to set at +// 'index' to 'colocated_set'. +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { + // CopyInsertion ensures root points-to set is unambiguous and distinct. + const auto& points_to = points_to_analysis.GetPointsToSet(instruction); + DCHECK(!points_to.IsAmbiguous()); + colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); +} + +// Given the interference map of a graph (the list of interfering node indices +// for each node), perform graph coloring such that interfering nodes are +// assigned to different colors. Returns the assigned color of the nodes, where +// the colors are represented as integer values [0, color_count). +std::vector ColorInterferenceGraph( + const std::vector>& interference_map) { + const int64 node_count = interference_map.size(); + + // Sort the nodes such that we assign nodes with more interference first. This + // relies on the common heuristic of assigning the most constrained node + // first, but it would be good to investigate other ordering heuristics too. + std::vector nodes(node_count); + std::iota(nodes.begin(), nodes.end(), 0); + std::sort(nodes.begin(), nodes.end(), + [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); + + const int64 kColorUnassigned = -1; + std::vector assigned_colors(node_count, kColorUnassigned); + for (int64 node : nodes) { + // Mark the colors that are already assigned to the neighbors. + std::vector available_colors(node_count, true); + for (int64 neighbor : interference_map[node]) { + int64 color = assigned_colors[neighbor]; + if (color != kColorUnassigned) { + available_colors[color] = false; + } + } + + // Find the color that is not yet assigned to the neighbors. + int64 color = kColorUnassigned; + for (color = 0; color < available_colors.size(); ++color) { + if (available_colors[color]) { + break; + } + } + CHECK_NE(color, kColorUnassigned); + assigned_colors[node] = color; + } + return assigned_colors; +} + +} // namespace + Status GatherComputationsByAllocationType( const HloModule* module, std::vector* thread_local_computations, @@ -107,7 +156,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s has conflicting allocation requirements (global " "and thread-local)", - computation->name().c_str()); + computation->name()); } if (is_thread_local) { @@ -130,7 +179,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s cannot contain call/while op because it " "requires thread-local buffer allocations", - computation->name().c_str()); + computation->name()); } worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. @@ -139,6 +188,7 @@ Status GatherComputationsByAllocationType( case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: // Map/reduce etc computations are always thread-local. @@ -146,9 +196,8 @@ Status GatherComputationsByAllocationType( true)); // Thread local. break; default: - return InternalError( - "Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); + return InternalError("Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode())); } } } @@ -168,65 +217,6 @@ Status GatherComputationsByAllocationType( return Status::OK(); } -// Checks that points-to set of 'instruction' is unambiguous and distinct -// (ensured by CopyInsertion), then adds the buffer from the points-to set at -// 'index' to 'colocated_set'. -const LogicalBuffer* AddBufferToColocatedSet( - const HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { - // CopyInsertion ensures root points-to set is unambiguous and distinct. - const auto& points_to = points_to_analysis.GetPointsToSet(instruction); - DCHECK(!points_to.IsAmbiguous()); - colocated_set->push_back(points_to.element(index)[0]); - return colocated_set->back(); -} - -// Given the interference map of a graph (the list of interfering node indices -// for each node), perform graph coloring such that interfering nodes are -// assigned to different colors. Returns the assigned color of the nodes, where -// the colors are represented as integer values [0, color_count). -std::vector ColorInterferenceGraph( - const std::vector>& interference_map) { - const int64 node_count = interference_map.size(); - - // Sort the nodes such that we assign nodes with more interference first. This - // relies on the common heuristic of assigning the most constrained node - // first, but it would be good to investigate other ordering heuristics too. - std::vector nodes(node_count); - std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); - - const int64 kColorUnassigned = -1; - std::vector assigned_colors(node_count, kColorUnassigned); - for (int64 node : nodes) { - // Mark the colors that are already assigned to the neighbors. - std::vector available_colors(node_count, true); - for (int64 neighbor : interference_map[node]) { - int64 color = assigned_colors[neighbor]; - if (color != kColorUnassigned) { - available_colors[color] = false; - } - } - - // Find the color that is not yet assigned to the neighbors. - int64 color = kColorUnassigned; - for (color = 0; color < available_colors.size(); ++color) { - if (available_colors[color]) { - break; - } - } - CHECK_NE(color, kColorUnassigned); - assigned_colors[node] = color; - } - return assigned_colors; -} - -} // namespace - size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { uint64 h = std::hash()(s.index()); h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); @@ -235,8 +225,8 @@ size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { } string BufferAllocation::Slice::ToString() const { - return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_, - ", size:", size_, "}"); + return absl::StrCat("{index:", index(), ", offset:", offset_, + ", size:", size_, "}"); } BufferAllocation::Slice BufferAllocation::GetSlice( @@ -297,7 +287,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { string BufferAllocation::ToString() const { string output; - Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); + StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size()); if (color().value() != 0) { StrAppend(&output, ", color ", color().value()); } @@ -329,11 +319,10 @@ string BufferAllocation::ToString() const { }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); - StrAppend(&output, - tensorflow::strings::Printf( - " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), - offset_size.offset, offset_size.size, - ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); + StrAppend(&output, absl::StrFormat( + " %s [%d,%d]: %s\n", buffer->ToString(), + offset_size.offset, offset_size.size, + ShapeUtil::HumanStringWithLayout(buffer->shape()))); } return output; } @@ -426,7 +415,7 @@ StatusOr BufferAssignment::GetUniqueSlice( return FailedPrecondition( "BufferAllocation::Slice for instruction %s at index %s cannot " "be determined at compile-time.", - instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } } else { VLOG(3) << "No allocation"; @@ -435,7 +424,7 @@ StatusOr BufferAssignment::GetUniqueSlice( if (result.allocation() == nullptr) { return FailedPrecondition( "BufferAllocation::Slice not assigned for instruction %s at index %s", - instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } return result; } @@ -626,19 +615,25 @@ Status BufferAssignment::ComputeSummaryStats() { stats_.total_allocation_bytes += allocation.size(); } - // Only compute total fragmentation if all computations are sequential. - SequentialHloOrdering::HloModuleSequence module_sequence; + // Only compute total fragmentation if all computations have schedules. + HloSchedule schedule(module_); + bool schedule_complete = true; for (const auto& computation : module_->computations()) { - const std::vector* sequence = - liveness_->hlo_ordering().SequentialOrder(*computation); - if (sequence != nullptr) { - module_sequence.emplace(computation, *sequence); + if (!computation->IsFusionComputation()) { + const std::vector* sequence = + liveness_->hlo_ordering().SequentialOrder(*computation); + if (sequence == nullptr) { + schedule_complete = false; + } else { + schedule.set_sequence(computation, *sequence); + } } } - if (module_sequence.size() == module_->computation_count()) { + if (schedule_complete) { + TF_RETURN_IF_ERROR(schedule.Verify()); TF_ASSIGN_OR_RETURN( const int64 min_size, - HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } @@ -647,39 +642,38 @@ Status BufferAssignment::ComputeSummaryStats() { string BufferAssignment::Stats::ToString() const { string s; - Appendf(&s, "BufferAssignment stats:\n"); - Appendf(&s, " parameter allocation: %10s\n", - HumanReadableNumBytes(parameter_allocation_bytes).c_str()); - Appendf(&s, " constant allocation: %10s\n", - HumanReadableNumBytes(constant_allocation_bytes).c_str()); - Appendf(&s, " maybe_live_out allocation: %10s\n", - HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str()); - Appendf(&s, " preallocated temp allocation: %10s\n", - HumanReadableNumBytes(preallocated_temp_allocation_bytes).c_str()); + StrAppendFormat(&s, "BufferAssignment stats:\n"); + StrAppendFormat(&s, " parameter allocation: %10s\n", + HumanReadableNumBytes(parameter_allocation_bytes)); + StrAppendFormat(&s, " constant allocation: %10s\n", + HumanReadableNumBytes(constant_allocation_bytes)); + StrAppendFormat(&s, " maybe_live_out allocation: %10s\n", + HumanReadableNumBytes(maybe_live_out_allocation_bytes)); + StrAppendFormat(&s, " preallocated temp allocation: %10s\n", + HumanReadableNumBytes(preallocated_temp_allocation_bytes)); if (preallocated_temp_fragmentation_bytes >= 0) { const double percent = 100. * preallocated_temp_fragmentation_bytes / preallocated_temp_allocation_bytes; - Appendf( + StrAppendFormat( &s, " preallocated temp fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(preallocated_temp_fragmentation_bytes).c_str(), - percent); + HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent); } - Appendf(&s, " total allocation: %10s\n", - HumanReadableNumBytes(total_allocation_bytes).c_str()); + StrAppendFormat(&s, " total allocation: %10s\n", + HumanReadableNumBytes(total_allocation_bytes)); if (total_fragmentation_bytes >= 0) { const double percent = 100. * total_fragmentation_bytes / total_allocation_bytes; - Appendf(&s, " total fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(total_fragmentation_bytes).c_str(), percent); + StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n", + HumanReadableNumBytes(total_fragmentation_bytes), percent); } return s; } string BufferAssignment::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "BufferAssignment:\n"); + absl::StrAppend(&output, "BufferAssignment:\n"); for (auto& allocation : allocations_) { - tensorflow::strings::StrAppend(&output, allocation.ToString()); + absl::StrAppend(&output, allocation.ToString()); } return output; } @@ -1070,12 +1064,25 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( // that seems to give the best results is lazy-best-fit, with all runs of // alloc / free calls sorted in decreasing size order. const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); + + // Returns a heap algorithm that chooses the best result from several + // algorithms. + auto get_heap_algorithm = [&](int64 alignment) { + auto algorithms = + absl::make_unique>>(); + algorithms->push_back(absl::make_unique( + absl::make_unique(alignment))); + algorithms->push_back( + absl::make_unique(alignment)); + return absl::make_unique(std::move(algorithms)); + }; + if (run_whole_module_heap_simulation) { // Run the heap simulation over the whole module. This reduces memory usage, // since buffers for kCall, kWhile, and kConditional sub-computations are // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(&assignment->module()); FlatSet all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; @@ -1083,7 +1090,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); - module_sequence[computation] = *instruction_sequence; + schedule.set_sequence(computation, *instruction_sequence); all_buffers_to_assign.insert(buffers_to_assign.begin(), buffers_to_assign.end()); } @@ -1099,9 +1106,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment)), - assignment->module(), module_sequence, + HeapSimulator::Run(get_heap_algorithm(alignment), + assignment->module(), schedule, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1129,9 +1135,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment)), - *computation, *instruction_sequence, + HeapSimulator::Run(get_heap_algorithm(alignment), *computation, + HloInstructionSequence(*instruction_sequence), assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1645,7 +1650,8 @@ StatusOr> BufferAssigner::CreateAssignment( XLA_VLOG_LINES(3, liveness->ToString()); XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); - // Can't use MakeUnique because BufferAssignment constructor is private. + // Can't use absl::make_unique because BufferAssignment constructor is + // private. std::unique_ptr assignment( new BufferAssignment(module, std::move(liveness), std::move(buffer_size), std::move(color_alignment))); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 94495290c131e22392079dc2d0237d990b646d3e..24ba7c16f548c10f58f41d2b88488939ca2d8e4d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" @@ -41,6 +41,17 @@ limitations under the License. namespace xla { +// Walk the call graph of the HLO module and place each computation into either +// thread_local_computations or global_computations depending upon whether the +// computation requires thread-local allocations or global allocations. The +// elements in thread_local_computations and global_computations are in post +// order (if computation A has an instruction which calls computation B, then A +// will appear after B in the vector). +Status GatherComputationsByAllocationType( + const HloModule* module, + std::vector* thread_local_computations, + std::vector* global_computations); + // This class abstracts an allocation of contiguous memory which can hold the // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range // of the allocation, represented by a Slice. A single BufferAllocation may hold diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index eccb146a0d7d628870be179a540d9750df3fe41c..795beb9ff5ceb2998a85fbd03d8bb1d3b2febc12 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" @@ -30,16 +30,18 @@ limitations under the License. #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -79,15 +81,14 @@ const std::vector GetInstructions(HloInstruction* root) { return main_list.GetInstructions(); } -class BufferAssignmentTest : public HloTestBase { +class BufferAssignmentTest : public HloVerifiedTestBase { protected: - BufferAssignmentTest() {} ~BufferAssignmentTest() override {} std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -98,7 +99,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignmentNoBuffersForConstants( HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -109,7 +110,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -119,15 +120,12 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignmentWithInstructionSequence( HloModule* module, - tensorflow::gtl::ArraySlice instruction_sequence, + absl::Span instruction_sequence, int64 alignment = 1) { - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[module->entry_computation()] = - std::vector(instruction_sequence.begin(), - instruction_sequence.end()); + HloSchedule schedule(module); + schedule.set_sequence(module->entry_computation(), instruction_sequence); return BufferAssigner::Run( - module, - xla::MakeUnique(module, module_sequence), + module, absl::make_unique(schedule), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -147,6 +145,17 @@ class BufferAssignmentTest : public HloTestBase { return builder.Build(); } + std::unique_ptr BuildReduceComputation(const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + auto param2 = + builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2)); + return builder.Build(); + } + // Builds a simple compare-to-limit (x < 4) computation for a While. // // condition: @@ -163,8 +172,8 @@ class BufferAssignmentTest : public HloTestBase { HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4)); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4)); return builder.Build(); } @@ -311,12 +320,12 @@ TEST_F(BufferAssignmentTest, ScalarConstant) { module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); } } @@ -335,13 +344,13 @@ TEST_F(BufferAssignmentTest, BufferForConst) { module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); EXPECT_TRUE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); @@ -363,7 +372,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation() // reports for the instruction directly. EXPECT_EQ(buffers->HasTopLevelAllocation(tuple), @@ -386,7 +395,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // The copy node now has an output buffer. GetAssignedOutputAllocation(*buffers, copy); } @@ -400,12 +409,14 @@ TEST_F(BufferAssignmentTest, Basic) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -413,7 +424,7 @@ TEST_F(BufferAssignmentTest, Basic) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -447,12 +458,14 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -472,7 +485,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module.get(), colorer); + auto buffers = RunColoredBufferAssignment(module, colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -506,12 +519,14 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -539,7 +554,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module.get(), colorer); + auto buffers = RunColoredBufferAssignment(module, colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -576,12 +591,14 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction( @@ -589,7 +606,7 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // Input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -640,7 +657,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size"; // Assigns buffers and fetches sizes. - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); int64 size0 = ValidateBuffers(level0, *buffers); int64 size1 = ValidateBuffers(level1, *buffers); @@ -675,10 +692,10 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { // output. (Reuse is not safe in the general case, as it reshapes and some // out-of-order reductions could overwrite an element before a use.) // - // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3) + // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3) auto module = CreateNewModule(); auto reduce_computation = - module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); + module->AddEmbeddedComputation(BuildReduceComputation("f32+f32")); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( @@ -699,7 +716,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); const std::vector instrs = GetInstructions(exp3); ValidateBuffers(instrs, *buffers); @@ -755,7 +772,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; // Assigns buffers and fetches sizes. - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); int64 size0 = ValidateBuffers(level0, *buffers); int64 sizec = ValidateBuffers(levelc, *buffers); int64 sizeb = ValidateBuffers(levelb, *buffers); @@ -820,7 +837,7 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { EXPECT_EQ(2, true_instrs.size()); EXPECT_EQ(2, false_instrs.size()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); ValidateBuffers(conditional_instrs, *buffers); ValidateBuffers(true_instrs, *buffers); ValidateBuffers(false_instrs, *buffers); @@ -858,7 +875,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // tanh and exp2 can reuse exp1's buffer EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1)); @@ -887,7 +904,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // negate and broadcast should share a buffer. EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -920,7 +937,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -957,7 +974,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -992,7 +1009,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1024,7 +1041,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // negate and broadcast should share a buffer. EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1062,7 +1079,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1106,7 +1123,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { HloInstruction::CreateMap(vec_shape, {call}, map_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Allocations for the map computation should be thread-local and not // live-out. @@ -1155,7 +1172,7 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // There should be four allocations: one for vector of pointers, and one for // each tuple element. @@ -1191,7 +1208,7 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Only some of the elements of the input param are liveout. EXPECT_FALSE( @@ -1228,13 +1245,14 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. auto builder = HloComputation::Builder(TestName()); + Literal elements[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}))); + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); EXPECT_EQ(3, assignment->Allocations().size()); } @@ -1248,7 +1266,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { /*operands=*/{}, /*custom_call_target=*/"foo_function")); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); EXPECT_EQ(3, assignment->Allocations().size()); EXPECT_TRUE( @@ -1279,7 +1297,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { HloInstruction::CreateCall(tuple_shape, {param}, sub_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); EXPECT_EQ(2, assignment->Allocations().size()); // Buffers for call are colocated with the sub-computation. @@ -1341,7 +1359,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { module->AddEntryComputation(std::move(a_computation)); module->AddEmbeddedComputation(std::move(b_computation)); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Buffers for call are colocated with the sub-computations. EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), @@ -1377,7 +1395,7 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Bitcast should get the same allocation as the param. EXPECT_EQ(1, assignment->Allocations().size()); @@ -1404,7 +1422,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Select shallow copies one of its operands so it defines its own top-level // buffer and receives its own allocation. @@ -1442,7 +1460,7 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // There should be no buffer reuse. The copy should not reuse the tuple // buffer. @@ -1471,17 +1489,20 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_ab = builder.AddInstruction( - HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); - auto dot_bc = builder.AddInstruction( - HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot( + shape_2x4, param_a, param_b, dot_dnums, precision_config)); + auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot( + shape_3x4, param_b, param_c, dot_dnums, precision_config)); builder.AddInstruction( - HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); + HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); // Run buffer assignment with alignment=1. auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1); + auto assignment = RunBufferAssignment(module, /*alignment=*/1); // There are 5 allocations: 3 parameters, 1 output, and 1 temp. EXPECT_EQ(5, assignment->Allocations().size()); @@ -1500,7 +1521,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { EXPECT_EQ(80, slice_bc.allocation()->size()); // Re-run buffer assignment with alignment=64. - assignment = RunBufferAssignment(module.get(), /*alignment=*/64); + assignment = RunBufferAssignment(module, /*alignment=*/64); EXPECT_EQ(5, assignment->Allocations().size()); slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); @@ -1531,12 +1552,14 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); builder.AddInstruction(HloInstruction::CreateBinary( @@ -1544,16 +1567,13 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); - // Trivially, the set of peak memory logical buffer(s) of an allocation with a - // single logical buffer should be exactly the logical buffer in that - // allocation. const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); const std::vector& peak_buffers = mul_buffer.PeakMemoryLogicalBuffers(); ASSERT_EQ(peak_buffers.size(), 1); - EXPECT_EQ(peak_buffers[0]->instruction(), mul); + EXPECT_EQ(peak_buffers[0]->instruction(), broadcast); } TEST_F(BufferAssignmentTest, PeakBuffers) { @@ -1589,7 +1609,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignmentWithInstructionSequence( - module.get(), {param, log, rev, neg, concat, root}); + module, {param, log, rev, neg, concat, root}); // The temporary buffer should hold the 4 interior instructions. const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); @@ -1645,7 +1665,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); const std::vector& peak_buffers = buffer.PeakMemoryLogicalBuffers(); @@ -1695,15 +1715,13 @@ ENTRY main { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); - + ParseAndVerifyModule(hlo_text); HloInstruction* constant_1 = - module->entry_computation()->GetInstructionWithName("constant.1.1"); + module().entry_computation()->GetInstructionWithName("constant.1.1"); HloInstruction* constant_2 = - module->entry_computation()->GetInstructionWithName("constant.1.2"); + module().entry_computation()->GetInstructionWithName("constant.1.2"); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(&module()); { const BufferAllocation& allocation_for_const_1 = @@ -1732,7 +1750,7 @@ ENTRY main { } } -class WhileBufferAssignmentTest : public HloTestBase { +class WhileBufferAssignmentTest : public HloVerifiedTestBase { protected: std::unique_ptr BuildWhileConditionComputation( const string& name) { @@ -1766,10 +1784,10 @@ class WhileBufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, xla::MakeUnique(module, sequence), + module, absl::make_unique(schedule), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1805,9 +1823,9 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -1831,8 +1849,8 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); // Verify 'input0' and read-only use while0{0} alias. EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(), @@ -1888,20 +1906,20 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); + ParseAndVerifyModule(module_str); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module->instruction_count(); + int64 instruction_count = module().instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); - ASSERT_EQ(instruction_count, module->instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(&module()).status()); + ASSERT_EQ(instruction_count, module().instruction_count()); // Get the instructions in the module. - const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* bcast = + module().entry_computation()->root_instruction(); const HloInstruction* param = - module->entry_computation()->parameter_instruction(0); + module().entry_computation()->parameter_instruction(0); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1909,7 +1927,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(&module()); TF_ASSERT_OK_AND_ASSIGN(auto slice_param, assignment->GetUniqueSlice(param, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -1956,20 +1974,20 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); + ParseAndVerifyModule(module_str); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module->instruction_count(); + int64 instruction_count = module().instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); - ASSERT_EQ(instruction_count, module->instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(&module()).status()); + ASSERT_EQ(instruction_count, module().instruction_count()); // Get the instructions in the module. - const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* bcast = + module().entry_computation()->root_instruction(); const HloInstruction* constant = - module->entry_computation()->GetInstructionWithName("constant.42"); + module().entry_computation()->GetInstructionWithName("constant.42"); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1977,7 +1995,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(&module()); TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, assignment->GetUniqueSlice(constant, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2070,24 +2088,31 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // any copies inserted for BufferAssignment to run. int64 instruction_count = module->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_IS_OK(copy_insertion.Run(module).status()); ASSERT_EQ(instruction_count, module->instruction_count()); // Create a sequential order among all the instructions in the entry // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. - SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = { - token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + schedule.set_sequence( + module->entry_computation(), + {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}); + TF_ASSERT_OK(schedule.Verify()); + TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), sequence), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2120,7 +2145,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -2141,8 +2166,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); // while0 and while1 buffers should be completely aligned. EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(), @@ -2184,13 +2209,13 @@ TEST_F(BufferAssignmentTest, TwoCalls) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); } - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } @@ -2214,15 +2239,14 @@ ENTRY Main { } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - HloRunner::CreateModuleFromString( - hlo_text, legacy_flags::GetDebugOptionsFromFlags())); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + ParseAndVerifyModule(hlo_text, config); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(&module()); - HloComputation* main = module->entry_computation(); - HloComputation* callee = module->GetComputationWithName("Callee"); + HloComputation* main = module().entry_computation(); + HloComputation* callee = module().GetComputationWithName("Callee"); EXPECT_NE(callee, nullptr); HloInstruction* param0 = callee->parameter_instruction(0); @@ -2245,29 +2269,6 @@ ENTRY Main { GetAllocation(*buffers, param0, {1, 1})); } -static bool IsPostOrderTraversal( - const std::vector& sequence) { - tensorflow::gtl::FlatSet seen_so_far; - auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { - return seen_so_far.count(instruction) == 0; - }; - - for (auto instruction : sequence) { - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), has_not_been_seen_yet) || - std::any_of(instruction->control_predecessors().begin(), - instruction->control_predecessors().end(), - has_not_been_seen_yet)) { - return false; // Not a post order. - } - if (!seen_so_far.insert(instruction).second) { - return false; // Not a "traversal". - } - } - - return true; -} - TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); @@ -2282,14 +2283,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto weights0 = builder.AddInstruction( HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto input1 = builder.AddInstruction( HloInstruction::CreateParameter(2, data_shape_, "input1")); auto weights1 = builder.AddInstruction( HloInstruction::CreateParameter(3, data_shape_, "weights1")); auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, one, {1})); + HloInstruction::CreateBroadcast(data_shape_, one, {})); auto cond = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -2309,41 +2310,40 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { HloInstruction::CreateGetTupleElement(data_shape_, while0, 0)); auto gte1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, while1, 1)); - auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( - while0->shape(), HloOpcode::kAdd, gte0, gte1)); + auto root_add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1)); module->AddEntryComputation(builder.Build()); { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); EXPECT_TRUE(result); } - RunCopyInsertion(module.get()); + RunCopyInsertion(module); - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); - // To trigger b/38494731, we want a specific Hlo sequence for the + // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. - sequence[module->entry_computation()] = { - input1, weights1, one, output1, while1->operand(0), while1, - input0, weights0, zero, output0, while0->operand(0), while0, - gte0, gte1, root_add}; + schedule.set_sequence(module->entry_computation(), + {input1, weights1, one, output1, while1->operand(0), + while1, input0, weights0, zero, output0, + while0->operand(0), while0, gte0, gte1, root_add}); - // If this ASSERT_TRUE fails, we constructed a bogus sequence above - // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); + // If this ASSERT fails, we constructed a bogus sequence above and this test + // itself is buggy. + TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), sequence), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run(module, + absl::make_unique(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); @@ -2361,9 +2361,9 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -2394,8 +2394,8 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); // Get BufferAllocation for root instruction. auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out) .ConsumeValueOrDie() diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 810d597e730c1823668c81598df6138655e58b55..9b2783a214a686f3148723d19bbc94421fc8b4e4 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -28,8 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -75,27 +75,25 @@ Status BufferLiveness::Analyze() { string BufferLiveness::ToString() const { std::vector pieces; - pieces.push_back(tensorflow::strings::Printf("BufferLiveness(module=%s):", - module_->name().c_str())); + pieces.push_back( + absl::StrFormat("BufferLiveness(module=%s):", module_->name())); pieces.push_back("HloOrdering:"); pieces.push_back(hlo_ordering_->ToString()); - pieces.push_back(tensorflow::strings::Printf("Aliased buffers:")); + pieces.push_back("Aliased buffers:"); for (const LogicalBuffer* buffer : aliased_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } - pieces.push_back(tensorflow::strings::Printf("Live out buffers:")); + pieces.push_back("Live out buffers:"); for (const LogicalBuffer* buffer : maybe_live_out_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, const LogicalBuffer& b) const { - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a)); - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b)); + TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a)); + TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b)); if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) { return false; diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 4a927b57674345f8b3493c098778182a299c5902..17e50905059ad2c92784d14132c1cb1f46f35ade 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -18,14 +18,16 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -119,8 +121,8 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); @@ -165,11 +167,11 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto module = CreateNewModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), sequence)) + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at @@ -215,8 +217,8 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -249,8 +251,8 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -290,12 +292,11 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, negate, exp, add}; - module_sequence.emplace(computation, order); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, negate, exp, add}); auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -338,13 +339,13 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build(add)); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, add, recv, - recv_done, send, send_done}; - module_sequence.emplace(computation, order); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, + {param, add, token, recv, recv_done, send, send_done}); + TF_ASSERT_OK(schedule.Verify()); auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); @@ -376,8 +377,8 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // All buffers should be live out except the param @@ -412,8 +413,8 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // Buffers in different computations should always interfere. @@ -439,22 +440,22 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}); - auto inner_tuple1 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); + Literal elements0[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; + auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]}); + Literal element1 = LiteralUtil::CreateR0(3); + auto inner_tuple1 = LiteralUtil::MakeTuple({&element1}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( - inner_tuple0->shape(), tuple_constant, 0)); + inner_tuple0.shape(), tuple_constant, 0)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // Only the element buffers of the tuple constant which are pointed to by @@ -518,8 +519,8 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -580,8 +581,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -610,11 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { protected: // Builds and runs a computation (see test case computation graphs below). - // Runs BufferLiveness on this computation. - // Returns whether buffer interference is detected between tuple-shaped - // parameter and root instructions at tuple element 1. - bool Run(const bool update_uses_tuple_element1, - const bool fuse_gte0 = false) { + std::unique_ptr BuildModule(const bool update_uses_tuple_element1, + const bool fuse_gte0) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -645,12 +643,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); // Create output tuple. - auto tuple_root = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. auto module = CreateNewModule(); - module->AddEntryComputation(BuildDummyComputation()); - auto* computation = module->AddEmbeddedComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); + auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. if (update_uses_tuple_element1) { computation->CreateFusionInstruction( @@ -666,16 +664,39 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { computation->CreateFusionInstruction({gte0}, HloInstruction::FusionKind::kLoop); } + return module; + } + // Returns whether buffer interference is detected between tuple-shaped + // parameter and root instructions at tuple element 1. + bool Run(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); // Run BufferLiveness on 'module'. - auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. + auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); } + bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); + // Run BufferLiveness on 'module'. + auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie(); + auto hlo_ordering = absl::make_unique(module.get()); + // Return whether or not buffers interference is detected between + // 'tuple_param0' and 'tuple_root' at shape index '{1}'. + auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); + return hlo_ordering->MayInterfere( + dataflow->GetUniqueValueAt(tuple_param0, {1}), + dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow); + } }; // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -693,6 +714,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false)); + EXPECT_FALSE( + RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases @@ -712,6 +735,8 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true)); + EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false, + /*fuse_gte0=*/true)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -736,6 +761,7 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) { EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true)); + EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true)); } class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { @@ -780,10 +806,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. - auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index 2bc556a9e270136f5f3eaf2433f8c96eeeaea0a2..fdf822c666b15afbc7553ca89d4f92ab08201869 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -17,11 +17,10 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h index f4be16e0843f64f41ef27539bf263ae98ce0ebf9..69b36463560a1fad4f62687e9014fb3fbe5bbd13 100644 --- a/tensorflow/compiler/xla/service/buffer_value.h +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -19,12 +19,12 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/int_type.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a23427f00ccd88bb0fe1d973a667f80ca54b14cd..23b2a327096dfdb3c756a4acc5476ec01dcac1b3 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,21 +17,21 @@ limitations under the License. #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::StrCat; +using absl::StrAppendFormat; +using absl::StrCat; string CallContextToString(CallContext context) { switch (context) { @@ -61,6 +61,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: return CallContext::kParallel; @@ -70,10 +71,10 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { } string CallSite::ToString() const { - return StrCat(instruction()->name(), " calls in context ", - CallContextToString(context()), ": ", - tensorflow::str_util::Join( - called_computations(), ", ", + return StrCat( + instruction()->name(), " calls in context ", + CallContextToString(context()), ": ", + absl::StrJoin(called_computations(), ", ", [](string* out, const HloComputation* computation) { out->append(computation->name()); })); @@ -236,8 +237,8 @@ void CallGraph::SetCallContexts() { /* static */ std::unique_ptr CallGraph::Build(const HloModule* module) { - // Constructor for CallGraph is private so MakeUnique can't be used. - auto call_graph = WrapUnique(new CallGraph(module)); + // Constructor for CallGraph is private so absl::make_unique can't be used. + auto call_graph = absl::WrapUnique(new CallGraph(module)); VLOG(2) << "Building call graph for:"; XLA_VLOG_LINES(2, module->ToString()); @@ -355,20 +356,20 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, string CallGraph::ToString() const { string out; - Appendf(&out, "Call graph for module %s:\n", module_->name().c_str()); + StrAppendFormat(&out, "Call graph for module %s:\n", module_->name()); for (const CallGraphNode& node : nodes()) { - Appendf(&out, "Computation %s:\n", node.computation()->name().c_str()); - Appendf(&out, " calls:\n"); + StrAppendFormat(&out, "Computation %s:\n", node.computation()->name()); + StrAppendFormat(&out, " calls:\n"); for (const HloComputation* callee : node.callees()) { - Appendf(&out, " %s\n", callee->name().c_str()); + StrAppendFormat(&out, " %s\n", callee->name()); } - Appendf(&out, " called by:\n"); + StrAppendFormat(&out, " called by:\n"); for (const HloComputation* caller : node.callers()) { - Appendf(&out, " %s\n", caller->name().c_str()); + StrAppendFormat(&out, " %s\n", caller->name()); } - Appendf(&out, " callsites:\n"); + StrAppendFormat(&out, " callsites:\n"); for (const CallSite& callsite : node.callsites()) { - Appendf(&out, " %s\n", callsite.ToString().c_str()); + StrAppendFormat(&out, " %s\n", callsite.ToString()); } } return out; diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 97d3811508adee1bf2d0942bcc69e3e34a41c8c3..3af2ab5edfd9faf4ac5193df4b823c21b55b2f7f 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -15,8 +15,8 @@ limitations under the License. // Call graph for an HLO module. -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ #include @@ -272,4 +272,4 @@ class CallGraph { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index cc80b7484313329104eec1ce71a150b47d8330c9..34f3f914d593bc603c4964663f9cafb70a136fd3 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloTestBase { +class CallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation( @@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) { HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) { HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) { HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(3, call_graph->nodes().size()); @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // Test visitation of only reachable nodes. { @@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. auto module = CreateNewModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 256d05a73e0bf61d959d21795c106286b52d0b19..1d4214044409ae06239506e610000c839450a030 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -96,7 +96,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { if (it == subcomputation_hlo_to_new_hlo_.end()) { return NotFound( "Could not find mapping from subcomputation HLO %s to a cloned HLO.", - subcomputation_hlo->ToString().c_str()); + subcomputation_hlo->ToString()); } return it->second; } diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index a8345a394d46c90a48305313dac0bcd9b06938ac..08c4aff4f7fc7fc332fc7f34ece019eb57d71f3a 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_ #include @@ -25,7 +25,7 @@ namespace xla { // For every kCall operation in the main computation, we inline the body of the // called function, and proceed recursively. -class CallInliner : public HloPassInterface { +class CallInliner : public HloModulePass { public: using InlinedInstructionMap = std::unordered_map; @@ -35,11 +35,11 @@ class CallInliner : public HloPassInterface { static StatusOr Inline(HloInstruction* call); ~CallInliner() override = default; - tensorflow::StringPiece name() const override { return "CallInliner"; } + absl::string_view name() const override { return "CallInliner"; } StatusOr Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index ff968bca297077c7cf869ff8d2becb8bf739dce3..e6b566543594a86eb5369ee9b7440f62618f6c5a 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -28,11 +28,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace op = xla::testing::opcode_matchers; @@ -41,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). -using CallInlinerTest = HloTestBase; +using CallInlinerTest = HloVerifiedTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -65,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), @@ -92,6 +91,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { module->AddEmbeddedComputation(just_false.Build()); HloComputation::Builder call_false_builder(TestName() + ".call_false"); + call_false_builder.AddInstruction( + HloInstruction::CreateParameter(0, pred, "param")); call_false_builder.AddInstruction( HloInstruction::CreateCall(pred, {}, false_computation)); HloComputation* call_false = @@ -106,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -162,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); } diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 13008efed1494402eaff47904c2e4797334381a1..3c2d1ae6d82ebc6c10d52194fd1cec5e291025f7 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -73,20 +73,20 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::HOST_TO_DEVICE) { return FailedPrecondition( "host-to-device channels cannot be used with a Send operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } if (channel.has_sender) { return FailedPrecondition( "when registering send, passed a channel handle that is already used " - "by a sender: %lld", + "by a sender: %d", handle.handle()); } channel.has_sender = true; @@ -95,13 +95,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::DEVICE_TO_HOST) { return FailedPrecondition( "device-to-host channels cannot be used with a Recv operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } @@ -109,7 +109,7 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (channel.receiver_count >= 1) { return FailedPrecondition( "when registering recv, passed a channel handle that is already used " - "by a receiver: %lld", + "by a receiver: %d", handle.handle()); } channel.receiver_count += 1; diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index d773558c284a7d645f2766bb88c50f7da3777e5d..52037bf9b52556c6aa2e66dd3209e25cf085cfe3 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -18,12 +18,12 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 7426672a7a2a9102bd5ea98bd51092982e1e09b4..96bd2616f5607de888a096f8392ceb68490276e3 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -62,7 +62,7 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, StatusOr>> CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata) { std::vector> hlo_modules; @@ -76,9 +76,9 @@ CompileOnlyService::CompileAheadOfTime( if (!directory_path.empty()) { HloSnapshot hlo_snapshot; *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; - string filename = tensorflow::strings::StrCat( - "computation_", instance.computation.id(), "__", - instance.computation.entry_computation_name()); + string filename = + absl::StrCat("computation_", instance.computation.id(), "__", + instance.computation.entry_computation_name()); const string& per_host_path = tensorflow::io::JoinPath( directory_path, tensorflow::port::Hostname()); @@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 1ac950bdd66bd034dfdafa8598ec506221e99c2f..61136a3e11fe15fb74eac257f46292c6cd24ce7d 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -50,12 +50,12 @@ class CompileOnlyService : public Service { // |CompileOnlyClient::CompileAheadOfTime| for additional details. StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options); StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 6b3b9820f09803c8a04504e6c35c22de51abf04b..687ecafe0c308ecc22857fae650c6998677f605d 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -101,7 +101,7 @@ Compiler::GetPlatformCompilers() { return NotFound( "could not find registered compiler for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } // And then we invoke the factory, placing the result into the mapping. diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 99abb9bae32b35652e84cddc7c38dbd97ecb5006..1fdda31c34a17a16f75e1efada542c2c2ea15038 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -48,11 +48,6 @@ namespace xla { // compuation. using ObjectFileData = std::vector; -// Contains the buffer sizes information needed to allocate buffers to execute -// an ahead-of-time computation. Entries which contain -1 designate a parameter -// which should be skipped over during allocation. -using BufferSizes = std::vector; - // Abstract superclass describing the result of an ahead-of-time compilation. class AotCompilationResult { public: diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index cb61f3da39fb8eef69fd81066d87a1da91a62935..af8f7f1027a40703137d6880a9865449c560a47b 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -52,9 +52,8 @@ string ComputationLayout::ToString() const { for (auto& param_layout : parameter_layouts_) { params.push_back(param_layout.ToString()); } - return tensorflow::strings::StrCat("(", - tensorflow::str_util::Join(params, ", "), - ") => ", result_layout_.ToString()); + return absl::StrCat("(", absl::StrJoin(params, ", "), ") => ", + result_layout_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 187ce568cbb6c6666e978b8c8114262313c70ba5..2210a8578ad73efb27dc9c230b142c55228d2af5 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -19,8 +19,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -29,12 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; namespace xla { @@ -60,8 +60,8 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { "computation_count=%d", proto.replica_count(), proto.computation_count()); } - auto assignment = MakeUnique(proto.replica_count(), - proto.computation_count()); + auto assignment = absl::make_unique( + proto.replica_count(), proto.computation_count()); for (int computation = 0; computation < proto.computation_count(); ++computation) { const auto& computation_device = proto.computation_devices(computation); @@ -132,7 +132,7 @@ StatusOr ComputationPlacer::AssignDevices( return NotFound( "could not find registered computation placer for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if (it->second.placer == nullptr) { @@ -156,7 +156,7 @@ ComputationPlacer::GetPlatformComputationPlacers() { } // namespace xla static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index b7be3ba605a89a736b032eaab5a5085ac64fc549..4ea3a13f2835c5fef99c274f14d7d683c9ff5fc8 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -28,8 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h index 063261e26d06e21a297e8e3c405898a17221b7ca..2223ad67534dc31fc2c56ce68bdc87e881f20f32 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.h +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -16,20 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { // HLO pass that removes kConditional with a constant predicate, replacing them // with their true or false computation as appropriate. -class ConditionalSimplifier : public HloPassInterface { +class ConditionalSimplifier : public HloModulePass { public: - tensorflow::StringPiece name() const override { - return "simplify-conditional"; - } + absl::string_view name() const override { return "simplify-conditional"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ac4a65ec6ae55fabd2b48ea2982b94f9551c8d2 --- /dev/null +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -0,0 +1,249 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +// ConvolutionVisitor traverses the HLO computation and rewrites Convolution +// operations with feature_group_count > 1 into convolutions with +// feature_group_count = 1. +class ConvolutionVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* convolution) override; + + // Runs the visitor on a computation. + static bool Run(HloComputation* computation); + + // Returns whether any convolution ops were rewritten. + const bool changed() const { return changed_; } + + ~ConvolutionVisitor() override = default; + + private: + explicit ConvolutionVisitor(HloComputation* computation) + : computation_(computation) {} + + // Current HloComputation instance the ConvolutionVisitor is traversing. + HloComputation* computation_; + + // Whether rewrite has occurred. + bool changed_ = false; +}; + +bool ConvolutionVisitor::Run(HloComputation* computation) { + ConvolutionVisitor visitor(computation); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; +} + +Shape ExpandedFilterShape(const Shape& shape, int64 group_count, + int64 input_feature_dim) { + int64 num_dims = shape.dimensions_size(); + CHECK_GE(num_dims, 2); + Shape expanded_shape = shape; + expanded_shape.set_dimensions( + input_feature_dim, shape.dimensions(input_feature_dim) * group_count); + return expanded_shape; +} + +// Returns a vector with 'group_count' many groups, where the i-th group +// consists of 'group_size' times the value i. +std::vector GetMaskIds(int64 group_size, int64 group_count) { + std::vector values; + for (int i = 0; i < group_count; ++i) { + for (int j = 0; j < group_size; ++j) { + values.push_back(i); + } + } + return values; +} + +// Create a mask for grouped convolution that will make a normal convolution +// produce the same results as a grouped convolution. For a [2, 1, 6] +// filter this returns a [2, 3, 6] mask +// 1 1 0 0 0 0 +// 0 0 1 1 0 0 +// 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 +// 0 0 1 1 0 0 +// 0 0 0 0 1 1 +// +// The first step is to create a rank 1 constant: +// 0 1 2 +// +// This is broadcasted to +// 0 0 0 0 0 0 +// 1 1 1 1 1 1 +// 2 2 2 2 2 2 +// +// 0 0 0 0 0 0 +// 1 1 1 1 1 1 +// 2 2 2 2 2 2 +// +// Then we create another rank 1 constant +// 0 0 1 1 2 2 +// +// This is broadcasted to +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// +// Finally we use the Eq op of these two broadcasted constants and get the +// desired mask. +HloInstruction* GetExpandedFilterMask( + const Shape& filter_shape, int64 input_feature_dim, + int64 output_feature_dim, int64 group_count, + const std::function)>& + add_instruction) { + Shape expanded_filter_shape = + ExpandedFilterShape(filter_shape, group_count, input_feature_dim); + Shape mask_shape = ShapeUtil::MakeShape( + S32, AsInt64Slice(expanded_filter_shape.dimensions())); + int64 output_feature = filter_shape.dimensions(output_feature_dim); + int64 group_size = filter_shape.dimensions(input_feature_dim); + + // Create a 'input_feature' sized linspace and 'output_feature' sized linspace + // that will be broadcasted into perpendicular dimensions and compared. + const std::vector input_feature_filter_mask = + GetMaskIds(group_size, group_count); + const std::vector output_feature_filter_mask = + GetMaskIds(output_feature / group_count, group_count); + + auto mask1 = add_instruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(input_feature_filter_mask))); + auto broadcasted_mask1 = add_instruction( + HloInstruction::CreateBroadcast(mask_shape, mask1, {input_feature_dim})); + auto mask2 = add_instruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(output_feature_filter_mask))); + auto broadcasted_mask2 = add_instruction( + HloInstruction::CreateBroadcast(mask_shape, mask2, {output_feature_dim})); + + // Compare the broadcasted output feature linspace to the input feature + // linspace to create a diagonal predicate. + Shape predicate_shape = ShapeUtil::MakeShape( + PRED, AsInt64Slice(expanded_filter_shape.dimensions())); + return add_instruction(HloInstruction::CreateBinary( + predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2)); +} + +Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + int64 group_count = convolution->feature_group_count(); + if (group_count == 1) { + return Status::OK(); + } + auto filter = convolution->mutable_operand(1); + changed_ = true; + auto add = [&](std::unique_ptr inst) { + return computation_->AddInstruction(std::move(inst)); + }; + + auto dim_numbers = convolution->convolution_dimension_numbers(); + int64 input_feature_dim = dim_numbers.kernel_input_feature_dimension(); + int64 group_size = filter->shape().dimensions(input_feature_dim); + int64 output_feature_dim = dim_numbers.kernel_output_feature_dimension(); + auto expanded_filter_shape = + ExpandedFilterShape(filter->shape(), group_count, input_feature_dim); + HloInstruction* filter_mask = GetExpandedFilterMask( + filter->shape(), input_feature_dim, output_feature_dim, group_count, add); + HloInstruction* expanded_filter; + // We want to repeat 'filter' in the 'input_feature_dim' dimension + // 'group_count' times. + if (group_size == 1) { + Shape reshaped_filter_shape = + ShapeUtil::DeleteDimension(input_feature_dim, filter->shape()); + auto reshaped_filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + std::vector broadcast_dims; + for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) { + if (i == input_feature_dim) { + continue; + } + broadcast_dims.push_back(i); + } + expanded_filter = add(HloInstruction::CreateBroadcast( + expanded_filter_shape, reshaped_filter, broadcast_dims)); + } else { + // We could possibly also use reshape, broadcast, reshape instead of concat + // here, but it would require more complex code, and for depthwise + // convolution we would never end up in this branch. + std::vector concat_operands(group_count, filter); + expanded_filter = add(HloInstruction::CreateConcatenate( + expanded_filter_shape, concat_operands, input_feature_dim)); + } + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); + auto zero_filter = + add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); + auto new_filter = add( + HloInstruction::CreateTernary(expanded_filter_shape, HloOpcode::kSelect, + filter_mask, expanded_filter, zero_filter)); + auto new_convolution = HloInstruction::CreateConvolve( + convolution->shape(), convolution->mutable_operand(0), new_filter, + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config()); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_convolution))); + return Status::OK(); +} + +} // namespace + +StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), before:\n" + + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (ConvolutionVisitor::Run(comp)) { + changed = true; + } + } + XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), after:\n" + + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..ce0138e56fbd51daaf5d3ac329ccbe31a9fdbde7 --- /dev/null +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +// A pass which rewrites convolutions with feature_group_count > 1 into +// convolutions with feature_group_count = 1. +class ConvolutionFeatureGroupConverter : public HloModulePass { + public: + ConvolutionFeatureGroupConverter() {} + + absl::string_view name() const override { + return "convolution-feature-group-converter"; + } + + // Run convolution rewriting on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..28373ebf636c7b6b3059dcf6cd931901ebc87fc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using ConvolutionFeatureGroupConverterTest = HloTestBase; +namespace op = testing::opcode_matchers; + +TEST_F(ConvolutionFeatureGroupConverterTest, + ConvertFeatureGroupCountEqualToInputFeatureDim) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2,2] { + %input = f32[1,2,2]{2,1,0} parameter(0) + %copy = f32[1,2,2]{2,0,1} copy(f32[1,2,2]{2,1,0} %input) + %filter = f32[1,1,2]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,2]{2,0,1} %copy, f32[1,1,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2 +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + ConvolutionFeatureGroupConverter converter; + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + // Make sure the convolution is converted to one with feature_group_count = 1. + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->feature_group_count(), 1); + // Verify that the filter operand has been replaced. + EXPECT_THAT(root->operand(1), + op::Select(op::Eq(op::Broadcast(op::Constant()), + op::Broadcast(op::Constant())), + op::Broadcast(op::Reshape(op::Parameter())), + op::Broadcast(op::Constant()))); +} + +TEST_F(ConvolutionFeatureGroupConverterTest, + ConvertFeatureGroupCountDivisorOfInputFeatureDim) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2,2] { + %input = f32[1,2,4]{2,1,0} parameter(0) + %copy = f32[1,2,4]{2,0,1} copy(f32[1,2,4]{2,1,0} %input) + %filter = f32[1,2,2]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,4]{2,0,1} %copy, f32[1,2,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2 +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + ConvolutionFeatureGroupConverter converter; + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + // Make sure the convolution is converted to one with feature_group_count = 1. + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->feature_group_count(), 1); + // Verify that the filter operand has been replaced. + EXPECT_THAT(root->operand(1), + op::Select(op::Eq(op::Broadcast(op::Constant()), + op::Broadcast(op::Constant())), + // We expect to see Concatenate here instead of + // Broadcast, because feature_group_count < input + // feature dimension. + op::Concatenate(op::Parameter(), op::Parameter()), + op::Broadcast(op::Constant()))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 36fb9b43aa20bad788a0638b4fed6c88fc9023f0..b65dfef9c9575b683b2656af2ccc151d87db2cd7 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -31,18 +33,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { - -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace { +using absl::StrAppend; + bool IsEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && @@ -312,7 +309,7 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } -// We add copies for all the indices of the true and false computaiton roots, +// We add copies for all the indices of the true and false computation roots, // in order to resolve interference. We later rely on the CopyRemover to drop // the unnecessary ones. Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, @@ -381,7 +378,7 @@ class CopyRemover { } string ToString() const { - string out = StrCat("CopyRemover, module ", module_->name(), "\n"); + string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n"); StrAppend(&out, " Buffer values, in dependency order:\n"); for (const HloBuffer& buffer : alias_analysis_.buffers()) { StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); @@ -482,7 +479,7 @@ class CopyRemover { // 'values' an entry is created in value_to_node which indicates the // respective ValueNode representing that value. void AddValueList( - tensorflow::gtl::ArraySlice values, + absl::Span values, tensorflow::gtl::FlatMap* value_to_node) { ValueNode* tail = nullptr; ValueNode* head = nullptr; @@ -648,7 +645,12 @@ class CopyRemover { // We can only perform copy elision if the resulting merged values have // totally ordered live ranges; otherwise the merged buffer would have // live range interference. - if (IsHead(*dest)) { + if (src->next == dest) { + // In the process of eliding copies, its possible for a copy to have the + // same source and destination buffer. In this case, the copy can be + // safely removed. + VLOG(2) << copy->name() << " source and destination buffers are same."; + } else if (IsHead(*dest)) { // The copy copies an arbitrary value in the source buffer (call it s_x) // and defines d_0, the first value in the destination buffer. After // merging, the values in the combined buffer must be strictly ordered @@ -858,16 +860,16 @@ class CopyRemover { for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { values.push_back(p->value); } - return StrCat("{", - Join(values, ", ", - [](string* s, const HloValue* value) { - StrAppend(s, value->ToShortString()); - }), - "}"); + return absl::StrCat("{", + absl::StrJoin(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); } string ToString() const { - string out = StrCat("BufferValueTracker:\n"); + string out = absl::StrCat("BufferValueTracker:\n"); StrAppend(&out, " Def-use chains in each buffer:\n"); for (const ValueNode* head : value_lists_) { StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), @@ -875,10 +877,10 @@ class CopyRemover { const ValueNode* p = head; do { StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", - Join(p->uses, "; ", - [](string* s, const HloUse* use) { - StrAppend(s, use->ToString()); - }), + absl::StrJoin(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), "\n"); p = p->next; @@ -955,16 +957,11 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { return Status::OK(); } -// Add copies to address special constraints on the roots of computations not -// related to live range interference: -// -// (1) Entry computation root must be unambiguous and distinct. -// -// (2) Any computation called by a kCall instruction must have an -// unambiguous root. -// -// (3) Constants and parameters cannot be live out of the entry computation -// +Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) { + std::unique_ptr call_graph = CallGraph::Build(module); + return AddSpecialCaseCopies(*call_graph, module); +} + Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, @@ -1060,15 +1057,6 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } - // Special case copies are not eligible for later copy elision passes. - indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) { - if (has_copy) { - HloInstruction* copy = *copies_added.mutable_element(index); - if (copy != nullptr) { - copy->SetCopyElisionAllowed(false); - } - } - }); if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); } @@ -1076,10 +1064,10 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, return Status::OK(); } -Status CopyInsertion::VerifyNoLiveRangeInterference(HloModule* module) { +Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering, + HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); - DependencyHloOrdering ordering(module); TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); return Status::OK(); } @@ -1096,8 +1084,7 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, std::unique_ptr call_graph = CallGraph::Build(module); for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - instruction->CopyElisionAllowed()) { + if (instruction->opcode() == HloOpcode::kCopy) { TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); } } @@ -1163,10 +1150,10 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + DependencyHloOrdering dep_ordering(module); + TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module)); - DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module)); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); @@ -1174,7 +1161,8 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + TF_DCHECK_OK( + VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module)); MaybeDumpModule("after copy insertion", *module); diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 5ba64b78a3c9aff5f323691df2ece9b5e6bf3232..c097089e30d59936a32f69c49123c398f1611ea3 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -43,9 +43,9 @@ namespace xla { // (3) The buffer set of the root instruction of the entry computation must be // unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and // InstructionAliasSet::IsDistinct return true. -class CopyInsertion : public HloPassInterface { +class CopyInsertion : public HloModulePass { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } // fusion_can_share_buffer: backend specific function that decides whether a // fusion can share buffer with its operand. @@ -77,15 +77,29 @@ class CopyInsertion : public HloPassInterface { Status RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module); - private: - // Verifies that no HLO values have interfering live ranged assuming the - // ordering used by copy insertion. - Status VerifyNoLiveRangeInterference(HloModule* module); + // Add copies to address special constraints on the roots of computations not + // related to live range interference: + // + // (1) Entry computation root must be unambiguous and distinct. + // + // (2) Any computation called by a kCall instruction must have an + // unambiguous root. + // + // (3) Constants and parameters cannot be live out of the entry computation + // + Status AddSpecialCaseCopies(HloModule* module); - Status AddCopiesToResolveInterference(HloModule* module); + // Verifies that no HLO values have interfering live ranges using the given + // ordering. + Status VerifyNoLiveRangeInterference(const HloOrdering& ordering, + HloModule* module); + private: + // Override which requires the caller to pass in a call graph. Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module); + Status AddCopiesToResolveInterference(HloModule* module); + // Backend specific function that decides whether a fusion can share buffer // with its operand. HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_; diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index cd735256b83f5f1d69a89e693de6064d460a36e5..892d0d7b547aaf1e7f1c55e4163d1e1fd9518def 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2007,5 +2007,46 @@ ENTRY TestComputation { InsertCopies(module.get()); } +TEST_F(CopyInsertionTest, NestedWhiles) { + // Verify that only no unnecessary copies remain after copy insertion for + // trivial nested whiles (b/112472605). + const string& hlo_string = R"( +HloModule TestModule + +cond.inner { + ROOT param.cond.inner = pred[] parameter(0) +} + +body.inner { + param.body.inner = pred[] parameter(0) + ROOT neg = pred[] negate(param.body.inner) +} + +cond.outer { + ROOT param.cond.outer = pred[] parameter(0) +} + +body.outer { + param.cond.outer = pred[] parameter(0) + ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner +} + +ENTRY TestComputation { + entry_param = pred[] parameter(0) + ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should only be a single copy inserted, and it's in the entry + // computation. + EXPECT_EQ(CountCopies(*module), 1); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::While(op::Copy(op::Parameter()))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 504b61d134a0099d055d0266408e1dfb94af5b2a..b7103118ac5cbd47e060b170a8e432e2ec93c0fd 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -20,7 +20,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") load( "//third_party/mkl:build_defs.bzl", - "if_mkl", + "mkl_deps", ) # Filegroup used to collect source files for dependency checking. @@ -50,16 +50,32 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], alwayslink = True, # Contains per-platform transfer manager registration ) +cc_library( + name = "buffer_info_util", + srcs = ["buffer_info_util.cc"], + hdrs = ["buffer_info_util.h"], + deps = [ + "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "cpu_compiler", srcs = ["cpu_compiler.cc"], hdrs = ["cpu_compiler.h"], deps = [ ":compiler_functor", + ":buffer_info_util", ":conv_canonicalization", ":cpu_copy_insertion", ":cpu_executable", @@ -73,6 +89,12 @@ cc_library( ":ir_emitter", ":parallel_task_assignment", ":simple_orc_jit", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ":target_machine_features", + "@com_google_absl//absl/types:span", + "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", @@ -87,6 +109,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", + "//tensorflow/compiler/xla/service:convolution_feature_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", @@ -100,7 +123,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:hlo_proto_util", - "//tensorflow/compiler/xla/service:hlo_scheduling", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", @@ -158,11 +181,13 @@ cc_library( ":runtime_conv2d_mkl", ":runtime_fft", ":runtime_fork_join", + ":runtime_key_value_sort", ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", + "@com_google_absl//absl/memory", "@llvm//:execution_engine", "@llvm//:core", "@llvm//:mc", # fixdeps: keep @@ -214,6 +239,9 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@llvm//:orc_jit", ], ) @@ -256,11 +284,15 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@llvm//:code_gen", "@llvm//:core", "@llvm//:support", @@ -305,6 +337,8 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -315,12 +349,12 @@ cc_library( hdrs = ["parallel_loop_emitter.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -347,6 +381,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -367,6 +402,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -380,6 +416,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:mc", "@llvm//:mc_disassembler", "@llvm//:object", @@ -403,6 +440,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", "@llvm//:analysis", "@llvm//:core", "@llvm//:ipo", @@ -425,12 +463,16 @@ cc_library( ], copts = runtime_copts(), deps = [ + "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "//tensorflow/stream_executor", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", ], ) @@ -484,10 +526,7 @@ cc_library( "//tensorflow/core:framework_lite", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", - ] + if_mkl([ - "@mkl_dnn", - "//third_party/mkl:intel_binary_blob", - ]), + ] + mkl_deps(), ) cc_library( @@ -541,10 +580,7 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", "//third_party/eigen3", - ] + if_mkl([ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ]), + ] + mkl_deps(), ) cc_library( @@ -592,6 +628,18 @@ cc_library( ], ) +cc_library( + name = "runtime_key_value_sort", + srcs = ["runtime_key_value_sort.cc"], + hdrs = ["runtime_key_value_sort.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_fork_join", srcs = ["runtime_fork_join.cc"], @@ -625,6 +673,8 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -637,8 +687,11 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -733,6 +786,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -764,6 +818,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -785,6 +840,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -801,6 +857,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -837,6 +895,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -884,6 +943,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], @@ -904,6 +965,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -929,6 +991,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..1942ea1a2af8a349de53bafe80977436f9740fc4 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" + +namespace xla { +namespace cpu { + +using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo; + +std::vector CreateBufferInfosFromBufferAssignment( + const BufferAssignment& buffer_assignment) { + std::vector buffer_infos; + for (const BufferAllocation& allocation : buffer_assignment.Allocations()) { + if (allocation.is_thread_local()) { + buffer_infos.push_back(BufferInfo::MakeOnStackBuffer(allocation.size())); + } else if (allocation.is_constant()) { + buffer_infos.push_back(BufferInfo::MakeConstant(allocation.size())); + } else if (allocation.is_entry_computation_parameter()) { + buffer_infos.push_back(BufferInfo::MakeEntryParameter( + /*size=*/allocation.size(), + /*param_number=*/allocation.parameter_number())); + } else { + buffer_infos.push_back(BufferInfo::MakeTempBuffer(allocation.size())); + } + } + return buffer_infos; +} + +std::vector CreateArgIndexTableFromBufferInfos( + absl::Span buffer_infos) { + std::vector result; + for (int64 i = 0; i < buffer_infos.size(); i++) { + if (buffer_infos[i].is_entry_parameter()) { + if (buffer_infos[i].entry_parameter_number() >= result.size()) { + result.resize(buffer_infos[i].entry_parameter_number() + 1); + } + result[buffer_infos[i].entry_parameter_number()] = i; + } + } + return result; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e9ee928ab290f2f5338bd7b3804dc43033e2042f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ + +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" + +namespace xla { +namespace cpu { +// Creates and returns a list of BufferInfo instances containing relevant +// information from `buffer_assignment`. +std::vector<::tensorflow::cpu_function_runtime::BufferInfo> +CreateBufferInfosFromBufferAssignment( + const BufferAssignment& buffer_assignment); + +// Creates and returns a table containing the mapping from entry computation +// parameters to buffer allocation indices. +// +// If this function returns V then entry parameter i has buffer allocation index +// V[i]. +std::vector CreateArgIndexTableFromBufferInfos( + absl::Span + buffer_infos); +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 128eea4828b5e514b2ba6b398898e4a5d228e746..73b03440cbb936017257b8a92f16dcc25d41e21c 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -35,7 +36,6 @@ limitations under the License. #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -205,7 +205,7 @@ void CompilerFunctor::AddTargetInfoPasses( llvm::legacy::PassManagerBase* passes) const { llvm::Triple target_triple(target_machine_->getTargetTriple()); auto target_library_info_impl = - MakeUnique(target_triple); + absl::make_unique(target_triple); target_library_info_impl->addVectorizableFunctions( VectorFunctionsForTargetLibraryInfoImpl()); passes->add( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 0985b9297fe487f3523826cb0978c17775549735..2d9978404cc9ec1e40fc61aaf794a8f1f06050bb 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -130,8 +130,9 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { // change the dimension mapping but not the dimension sizes. For // example, input height and width are the same as before the reshapes. HloInstruction* new_conv = module->entry_computation()->AddInstruction( - HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, - hlo->window(), new_dnums)); + HloInstruction::CreateConvolve( + new_conv_shape, new_input, new_kernel, hlo->feature_group_count(), + hlo->window(), new_dnums, hlo->precision_config())); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index e6fd1499edd0095395194200a5b444ad61e7e39d..becee3f81fc34c73040d53e4a261bc3a656cd78c 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -31,14 +31,14 @@ namespace cpu { // called canonical convolutions). This pass expands non-canonical convolutions // into reshapes and canonical convolutions, so that these non-canonical // convolutions can run faster. -class ConvCanonicalization : public HloPassInterface { +class ConvCanonicalization : public HloModulePass { public: explicit ConvCanonicalization( const TargetMachineFeatures* target_machine_features) : target_machine_features_(*target_machine_features) {} ~ConvCanonicalization() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "convolution-canonicalization"; } diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 547d4c696da5cfdde3dece03250ae5fa51c92f25..2083f440fdd971db1b675d005664d25e6de53dbe 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloTestBase { +class ConvCanonicalizationTest : public HloVerifiedTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -84,7 +84,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -95,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -146,7 +147,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -156,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 8cbe9a1b0d5b0553b1121d544196412f36f8ce43..18fc144efe0023c0893adfcb16eda3341c0938d3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -26,6 +26,8 @@ limitations under the License. // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc" // IWYU pragma: no_include "llvm/Config/Targets.def.inc" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" @@ -42,7 +44,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" @@ -50,6 +51,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" @@ -74,12 +77,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" @@ -87,6 +90,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -98,11 +102,10 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace cpu { +using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo; CpuAotCompilationOptions::CpuAotCompilationOptions( string triple, string cpu_name, string features, string entry_point_name, @@ -120,11 +123,11 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const { } CpuAotCompilationResult::CpuAotCompilationResult( - ObjectFileData object_file_data, BufferSizes buffer_sizes, + ObjectFileData object_file_data, std::vector buffer_infos, int64 result_buffer_index, std::unique_ptr hlo_profile_printer_data) : object_file_data_(std::move(object_file_data)), - buffer_sizes_(std::move(buffer_sizes)), + buffer_infos_(std::move(buffer_infos)), result_buffer_index_(result_buffer_index), hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {} @@ -231,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { std::unordered_map* hlo_to_profile_idx_; const std::unordered_map& assigned_indices_; }; -} // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, - llvm::TargetMachine* target_machine) { - LLVMTargetMachineFeatures target_machine_features(target_machine); +} // namespace - // Optimization pipeline. - HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker(); +Status CpuCompiler::RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes through layout assignment"); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( @@ -255,11 +258,13 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(&target_machine_features); + pipeline.AddPass(); + pipeline.AddPass(target_machine_features); { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(); + pass.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pass.AddPass( /*rewrite_training_op=*/true, @@ -273,7 +278,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. - pipeline.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -286,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, } pipeline.AddPass(); pipeline.AddPass( - [&target_machine_features]( - const HloInstruction& dot, + [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, target_machine_features) + return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -297,17 +301,35 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); + pipeline.AddPass(); + ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_entry_computation_layout(), &target_machine_features); + module->mutable_entry_computation_layout(), target_machine_features); + return pipeline.Run(module).status(); +} + +Status CpuCompiler::RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes after layout assignment"); + // After layout assignment, use a layout-sensitive verifier. + auto& after_layout_assn = + pipeline.AddPass("after layout assignment"); + after_layout_assn.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. { auto& pass = pipeline.AddPass>( - "after layout assignement"); + "simplification after layout assignement"); + pass.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, @@ -315,7 +337,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass(); pass.AddPass(/*is_layout_sensitive=*/true); } + pipeline.AddPass(BF16, F32); + // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -328,14 +352,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, // binary size (and most AOT applications are single-threaded). // TODO(b/29630486) Support multi-threaded AOT. pipeline.AddPass( - max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); + max_parallelism, ShapeSizeBytesFunction(), target_machine_features); } - // Copy insertion should be performed immediately before IR emission to avoid - // inserting unnecessary copies (later pass adds an instruction which - // materializes the value) or missing a necessary copy (later pass removes an - // instruction which materializes a value). DCE must be run immediately before - // (and sometime after) copy insertion, to avoid dead code from interfering - // with the rewrites. + // Copy insertion should be performed immediately before IR emission to + // avoid inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes + // an instruction which materializes a value). DCE must be run immediately + // before (and sometime after) copy insertion, to avoid dead code from + // interfering with the rewrites. pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -343,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, return pipeline.Run(module).status(); } +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile, + &target_machine_features)); + return RunHloPassesAfterLayoutAssn(module, is_aot_compile, + &target_machine_features); +} + namespace { // Align buffers to 16-byte boundaries. @@ -354,7 +387,7 @@ llvm::TargetOptions CompilerTargetOptions( llvm::TargetOptions target_options; llvm_ir::SetTargetOptions( /*fast_math_enabled=*/module_config.debug_options() - .xla_enable_fast_math(), + .xla_cpu_enable_fast_math(), &target_options); return target_options; } @@ -446,7 +479,7 @@ Status CreateHloProfilingArtifacts( computation_to_profile_idx, std::unique_ptr* hlo_profile_index_map, std::unique_ptr* hlo_profile_printer_data) { - *hlo_profile_index_map = MakeUnique(module); + *hlo_profile_index_map = absl::make_unique(module); const HloComputation& entry_computation = *module.entry_computation(); TF_ASSIGN_OR_RETURN( @@ -513,15 +546,15 @@ StatusOr> CpuCompiler::RunBackend( &pre_optimization_ir_hook, &post_optimization_ir_hook)); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = xla::MakeUnique(); + auto llvm_context = absl::make_unique(); auto llvm_module = - xla::MakeUnique("__compute_module", *llvm_context); + absl::make_unique("__compute_module", *llvm_context); - auto jit = xla::MakeUnique( + auto jit = absl::make_unique( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_hook, post_optimization_ir_hook); llvm_module->setDataLayout(jit->data_layout()); @@ -551,20 +584,17 @@ StatusOr> CpuCompiler::RunBackend( // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. + // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module.get(), + absl::make_unique(schedule), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -595,9 +625,10 @@ StatusOr> CpuCompiler::RunBackend( } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() @@ -605,9 +636,10 @@ StatusOr> CpuCompiler::RunBackend( : entry_computation->name(); TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, - ir_emitter.EmitComputation(entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &module_sequence.at(entry_computation))); + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -651,9 +683,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, // so we bail if the configs have conflicting flags. At the moment, the only // flag that needs to be consistent is fast-math. const bool fast_math_enabled = - modules[0]->config().debug_options().xla_enable_fast_math(); + modules[0]->config().debug_options().xla_cpu_enable_fast_math(); for (const auto& module : modules) { - if (module->config().debug_options().xla_enable_fast_math() != + if (module->config().debug_options().xla_cpu_enable_fast_math() != fast_math_enabled) { return InvalidArgument( "All HLO module configs must have the same value for " @@ -672,8 +704,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, const llvm::Target* target = llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); if (target == nullptr) { - return InternalError("TargetRegistry::lookupTarget failed: %s", - error.c_str()); + return InternalError("TargetRegistry::lookupTarget failed: %s", error); } llvm::Reloc::Model reloc_model = llvm::Reloc::Static; @@ -709,7 +740,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); llvm::StringRef features = llvm_ir::AsStringRef(options.features()); llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config()); - std::unique_ptr target_machine = WrapUnique( + std::unique_ptr target_machine = absl::WrapUnique( target->createTargetMachine(triple.getTriple(), cpu_name, features, CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None, opt_level)); @@ -740,20 +771,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module, - xla::MakeUnique(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -793,18 +822,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, - embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation(computation, entry_point_name, - /*is_top_level_computation=*/true, - &module_sequence.at(computation))); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + computation, entry_point_name, + /*is_top_level_computation=*/true, + &schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); @@ -830,7 +859,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); std::unique_ptr object_file = @@ -838,39 +867,14 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, ObjectFileData object_file_data(object_file->getBufferStart(), object_file->getBufferEnd()); - BufferSizes buffer_sizes; - for (const BufferAllocation& allocation : assignment->Allocations()) { - // Callers don't need to allocate anything for thread-local temporary - // buffers. They are lowered to allocas. - if (allocation.is_thread_local()) { - buffer_sizes.push_back(-1); - continue; - } - - // Callers don't need to allocate anything for constant buffers. They are - // lowered to globals. - if (allocation.is_constant()) { - buffer_sizes.push_back(-1); - continue; - } - - // Callers don't need to allocate anything for entry computation buffers, - // but they do need to stash the pointer to the entry computation buffer - // in the temp buffer table. See the comment on - // XlaCompiledCpuFunction::StaticData::temp_sizes. - if (allocation.is_entry_computation_parameter()) { - buffer_sizes.push_back(-allocation.parameter_number() - 2); - continue; - } - - buffer_sizes.push_back(allocation.size()); - } + std::vector buffer_infos = + CreateBufferInfosFromBufferAssignment(*assignment); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment->GetUniqueTopLevelOutputSlice()); - results.emplace_back(MakeUnique( - std::move(object_file_data), std::move(buffer_sizes), + results.emplace_back(absl::make_unique( + std::move(object_file_data), std::move(buffer_infos), result_slice.index(), std::move(hlo_profile_printer_data))); } @@ -892,7 +896,7 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::host::kHostPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index e56f9f01134f84b4698c078b750b0c1fdca7748e..f2af923782df268e3e6da3895ec35579ab6aa51f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -18,12 +18,14 @@ limitations under the License. #include +#include "absl/types/span.h" #include "llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -78,7 +80,8 @@ class CpuAotCompilationOptions : public AotCompilationOptions { class CpuAotCompilationResult : public AotCompilationResult { public: CpuAotCompilationResult( - ObjectFileData object_file_data, BufferSizes buffer_sizes, + ObjectFileData object_file_data, + std::vector<::tensorflow::cpu_function_runtime::BufferInfo> buffer_infos, int64 result_buffer_index, std::unique_ptr hlo_profile_printer_data); ~CpuAotCompilationResult(); @@ -88,17 +91,20 @@ class CpuAotCompilationResult : public AotCompilationResult { } const ObjectFileData& object_file_data() const { return object_file_data_; } - const BufferSizes& buffer_sizes() const { return buffer_sizes_; } + const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>& + buffer_infos() const { + return buffer_infos_; + } int64 result_buffer_index() const { return result_buffer_index_; } private: // Contains the compiled computation: an object file. const ObjectFileData object_file_data_; - // The list of buffer sizes which should be allocated in order to execute the - // compiled computation. These buffers are used for temporary buffers used - // ephemerally during computation as well as the output result. - const BufferSizes buffer_sizes_; + // A list of BufferInfo objects describing the buffers used by the XLA + // computation. + const std::vector<::tensorflow::cpu_function_runtime::BufferInfo> + buffer_infos_; // Contains which buffer index into |buffer_sizes| was designated to the // result of the computation. This buffer should be passed into the output @@ -152,6 +158,16 @@ class CpuCompiler : public LLVMCompiler { Status RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine); + // Runs HLO passes up to and including layout assignment. + Status RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features); + + // Runs HLO passes after layout assignment. + Status RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features); + TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h index 3313d1e6eb71bff39f509c3d24858568df786422..076235f8874b5de57075fb690dd1b9111b6838a6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -30,13 +30,13 @@ namespace xla { // // TODO(b/62548313): Remove this when buffer assignment is smarter // (module-scoped). -class CpuCopyInsertion : public HloPassInterface { +class CpuCopyInsertion : public HloModulePass { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index 4db7fa446ea9188940f930bcadf753bd3e6b79e3..c9fb34be1cd582c71618c770c892058c233c571a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloTestBase { +class CpuCopyInsertionTest : public HloVerifiedTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*module), 3); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 946f5124b87bc011df4f3553077dbb37a3333ed2..29abf38e439d919ff93629ed992cb3ff93a929bd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -22,6 +22,9 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -35,9 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" @@ -75,9 +75,9 @@ CpuExecutable::CpuExecutable( StatusOr, std::vector>> -CpuExecutable::CreateTempArray( +CpuExecutable::CreateBufferTable( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::vector unowning_buffers( assignment_->Allocations().size()); std::vector owning_buffers( @@ -136,19 +136,19 @@ CpuExecutable::CreateTempArray( Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice buffers, + absl::Span buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: // // void function(void* result, const void* run_options, void** args_array, - // void** temps_array) + // void** buffer_table) // // result: Points at the result. // run_options: the ExecutableRunOptions object. // args_array: null - // temps_array: An array of pointers, containing pointers to temporary buffers - // required by the executable adn pointers to entry computation - // parameters. + // buffer_table: An array of pointers, containing pointers to temporary + // buffers required by the executable adn pointers to entry computation + // parameters. // uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -171,20 +171,19 @@ Status CpuExecutable::ExecuteComputeFunction( void* result_buffer = buffer_pointers[result_slice.index()]; if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; - VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[null], void* temps[%zu], " - "uint64 profile_counters[%zu])", + VLOG(3) << absl::StrFormat( + " func(void* result, void* params[null], void* buffer_table[%u], " + "uint64 profile_counters[%u])", buffer_pointers.size(), profile_counters_size); - VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); + VLOG(3) << absl::StrFormat(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { - tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); + absl::StrAppend(out, absl::StrFormat("%p", p)); }; VLOG(3) << " params = nullptr"; - VLOG(3) << tensorflow::strings::Printf( - " temps = [%s]", - tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); - VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters); + VLOG(3) << absl::StrFormat( + " buffer_table = [%s]", + absl::StrJoin(buffer_pointers, ", ", ptr_printer)); + VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters); } compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), @@ -209,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction( StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::MutableArraySlice buffers) { + absl::Span buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( /*on_host_shape=*/result_shape(), @@ -247,36 +246,33 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( StatusOr CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented("Points-to set of root instruction is ambiguous"); - } - - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - - std::vector owning_buffers; - std::vector unowning_buffers; TF_ASSIGN_OR_RETURN( - std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); - - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), unowning_buffers, hlo_execution_profile)); - - return CreateResultShapedBuffer(run_options, &owning_buffers); + auto result, + ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile)); + TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone()); + return std::move(result); } StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { if (hlo_profiling_enabled()) { return Unimplemented( "Asynchronous execution on stream with hlo profiling is not yet " "supported on CPU."); } + return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr); +} + +StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + absl::Span arguments, + HloExecutionProfile* hlo_execution_profile) { + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } auto* host_stream = dynamic_cast( run_options->stream()->implementation()); @@ -286,11 +282,12 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( std::vector unowning_buffers; TF_ASSIGN_OR_RETURN( std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); + CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(), + arguments)); - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - CreateResultShapedBuffer(run_options, &owning_buffers)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer result, + CreateResultShapedBuffer(run_options, absl::MakeSpan(owning_buffers))); // At this point, `unowning_buffers` contains unowning pointers to all of our // buffers, and `buffers` contains owning pointers to the non-live-out @@ -303,26 +300,27 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( // // We also need to change the types of some of the variables we capture: // run_options needs to change from a pointer to a value type, and arguments - // needs to change from an ArraySlice into a vector. We use a struct instead + // needs to change from a Span into a vector. We use a struct instead // of a lambda to make this explicit. struct AsyncRunTask { CpuExecutable* executable; ServiceExecutableRunOptions run_options; std::vector unowning_buffers; std::shared_ptr> buffers; + HloExecutionProfile* hlo_execution_profile; void operator()() { // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), unowning_buffers, - /*hlo_execution_profile=*/nullptr)); + &run_options.run_options(), unowning_buffers, hlo_execution_profile)); } }; host_stream->EnqueueTask( AsyncRunTask{this, *run_options, std::move(unowning_buffers), std::make_shared>( - std::move(owning_buffers))}); + std::move(owning_buffers)), + hlo_execution_profile}); return std::move(result); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8af8a5dfec2834678418f069619ba88b01633361..3c3c047bfe8ee0d1ad90ede2432a86264f47870b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -57,12 +57,12 @@ class CpuExecutable : public Executable { StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -74,9 +74,10 @@ class CpuExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); // Type of the computation function we expect in the JIT. - using ComputeFunctionType = void (*)( - void* /*result*/, const ExecutableRunOptions* /*run_options*/, - const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/); + using ComputeFunctionType = + void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/, + const void** /*args*/, void** /*buffer_table*/, + int64* /*profile_counters*/); const ComputeFunctionType& compute_function() const { return compute_function_; @@ -85,15 +86,25 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: - // Creates an array suitable for passing as the "temps" argument to the JIT - // compiled function pointer. + // This is for sharing the code between ExecuteOnStream and + // ExecuteAsyncOnStream. + // + // Notice that it's tricky to use correctly, as the profile object (when it + // exists) must out-live the task. + StatusOr ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + absl::Span arguments, + HloExecutionProfile* hlo_execution_profile); + + // Creates an array suitable for passing as the "buffer_table" argument to the + // JIT compiled function pointer. // // Returns (unowning_buffers, owning_buffers) where: // - // - unowning_buffers.data() can be passed as the temps argument as-is and - // includes pointers to the scratch storage required by the computation, - // the live-out buffer into which the result will be written and entry - // computation parameters. + // - unowning_buffers.data() can be passed as the buffer_table argument as-is + // and includes pointers to the scratch storage required by the + // computation, the live-out buffer into which the result will be written + // and entry computation parameters. // // - owning_buffers contains owning pointers to the buffers that were // allocated by this routine. This routine allocates buffers for temporary @@ -101,22 +112,21 @@ class CpuExecutable : public Executable { // result. StatusOr, std::vector>> - CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice arguments); + CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + absl::Span arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. - Status ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile); + Status ExecuteComputeFunction(const ExecutableRunOptions* run_options, + absl::Span buffers, + HloExecutionProfile* hlo_execution_profile); // Creates a ScopedShapedBuffer for holding the result of the computation, // moving buffers out of allocated_buffers and into the result as appropriate. // The addresses are set according to buffer assignment. StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::MutableArraySlice buffers); + absl::Span buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc index 7bd4741a04b1135d9780e0cf765b7b33378526e1..7fbe0fa157c57eb0c274662a1de95cf5328ccfa8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr CpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "CPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 2924b6365943f0a3ec998d7a77767a76cbb576ae..a39a9d4724655370454d60fbb7b474f223bd8a85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -23,14 +23,12 @@ namespace xla { // This pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the CPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). -class CpuHloSupportChecker : public HloPassInterface { +class CpuHloSupportChecker : public HloModulePass { public: CpuHloSupportChecker() = default; ~CpuHloSupportChecker() override = default; - tensorflow::StringPiece name() const override { - return "cpu_hlo_support_checker"; - } + absl::string_view name() const override { return "cpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index 0f463e6de623fc6ab43d685ff2a5d6882ba7b8a2..be1208fb2df2a1a11a093810b5f6c2a83f468062 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloTestBase { +class CpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("CPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index b40d264c03aba6e9308e8a621ae86e180e33c335..f9cd61bea3dc86cadff99d4a90eca44c16520823 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -35,7 +35,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kGather || - hlo.opcode() == HloOpcode::kPad || + hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kReverse || hlo.opcode() == HloOpcode::kSlice || @@ -78,7 +78,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (!CanBeLoopFused(*producer)) { - VLOG(2) << "Producer is not fusile."; + VLOG(2) << "Producer is not fusible."; return false; } @@ -140,7 +140,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (CanBeLoopFused(*consumer)) { - VLOG(2) << "Fusing: consumer is elementwise or fusile."; + VLOG(2) << "Fusing: consumer is elementwise or fusible."; return true; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 991b14f17dbc8cd061af98e032824d3f7075e78b..7d99b914d4f5e5d27722bcd098d2ae0c54a36a23 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -18,11 +18,13 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" namespace op = xla::testing::opcode_matchers; @@ -37,7 +39,11 @@ std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + precision_config); } TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { @@ -566,7 +572,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { HloOpcode::kParameter, HloOpcode::kParameter}); } -TEST_F(OpcodeFusionTest, MessOfFusileNodes) { +TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); @@ -691,14 +697,15 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, auto* addend = builder.AddInstruction( HloInstruction::CreateParameter(2, dot_shape, "param2")); - auto* dot = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + auto* dot = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); builder.AddInstruction( HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); if (add_extra_use_for_dot) { + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); builder.AddInstruction( - HloInstruction::CreateOutfeed(dot_shape, dot, "no_config")); + HloInstruction::CreateOutfeed(dot_shape, dot, token, "no_config")); } module->AddEntryComputation(builder.Build()); @@ -772,8 +779,8 @@ class GatherLoopFusionTest TEST_P(GatherLoopFusionTest, GatherLoopFusion) { const GatherLoopFusionTestSpec& spec = GetParam(); - string hlo_string = tensorflow::strings::StrCat( - "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); + string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n", + spec.hlo_computation_text); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(hlo_string)); @@ -791,11 +798,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[3,2] broadcast(one), dimensions={} ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) @@ -807,11 +814,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) @@ -823,11 +830,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -839,11 +846,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -855,11 +862,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -871,11 +878,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[1,1] broadcast(one), dimensions={} ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) @@ -887,11 +894,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index aa872d5ec9e7593b8d2f731421c17af590729529..bfecbd6e017893e4f6d3dcbc01d46c899e6060fa 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -34,8 +34,8 @@ namespace cpu { // instruction stream. namespace { -using ::tensorflow::gtl::nullopt; -using ::tensorflow::gtl::optional; +using absl::nullopt; +using absl::optional; using ShouldMakeOperandColMajorCache = tensorflow::gtl::FlatMap; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 3681d12d8da818d06d2f690024008c9ccb896286..4668f3872dad598edf4c7680e1b601622104ab3e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -70,7 +70,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -107,9 +107,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); @@ -151,9 +151,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); @@ -189,7 +189,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateParameter(0, rhs_shape, "param0")); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -229,7 +229,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1)); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -276,8 +276,8 @@ static StatusOr RunDotOutputFusion( HloInstruction::CreateParameter(1, dot_shape, "param1")); HloInstruction* dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape))); - HloInstruction* dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + HloInstruction* dot_result = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); HloInstruction* add_result; if (dot_operand_idx_in_add == 0) { add_result = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 3ed7876715f64191f6e652d2b5cb1673df9a1b94..b8ace5702688096822573c7afae234cbcbe77b28 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace { @@ -45,17 +46,16 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) { return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0; } -tensorflow::gtl::optional LlvmIrGemvTilingFactor( - const HloModuleConfig& config) { +absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); auto it = extra_options_map.find(kLlvmIrDotTilingFactor); int64 tiling_factor; if (it != extra_options_map.end() && - tensorflow::strings::safe_strto64(it->second, &tiling_factor)) { + absl::SimpleAtoi(it->second, &tiling_factor)) { return tiling_factor; } - return tensorflow::gtl::nullopt; + return absl::nullopt; } bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { @@ -64,38 +64,37 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; } -static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, - tensorflow::StringPiece suffix) { +static absl::string_view RemoveSuffix(absl::string_view str, + absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); return str.substr(0, str.size() - suffix.size()); } -tensorflow::gtl::optional> LlvmIrGemmTileSize( +absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); auto it = extra_options_map.find(kLlvmIrGemmTileSize); if (it == extra_options_map.end()) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } - std::vector tile_components = - tensorflow::str_util::Split(it->second, ':'); + std::vector tile_components = absl::StrSplit(it->second, ':'); CHECK_EQ(tile_components.size(), 3); int64 tile_size_m; int64 tile_size_k; int64 tile_size_n_in_vector_width; - CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); - CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + CHECK(absl::SimpleAtoi(tile_components[0], &tile_size_m)); + CHECK(absl::SimpleAtoi(tile_components[1], &tile_size_k)); - tensorflow::StringPiece tile_size_n_in_vector_width_str = + absl::string_view tile_size_n_in_vector_width_str = RemoveSuffix(tile_components[2], "*vectwidth"); - CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, - &tile_size_n_in_vector_width)); + CHECK(absl::SimpleAtoi(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); return std::tuple(tile_size_m, tile_size_k, tile_size_n_in_vector_width); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 429b9e16cbdd6f623919533582481f1640118081..47c7eb13b6e4cc05a23f82b8d2a25249f4b82ac0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,9 +27,8 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); -tensorflow::gtl::optional LlvmIrGemvTilingFactor( - const HloModuleConfig& config); -tensorflow::gtl::optional> LlvmIrGemmTileSize( +absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); +absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); } // namespace options diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 639064040f521a9e84bd87c5d05f674204e4d6e2..20cf8557354b161451cf5b7825ccfce57d96875a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -17,18 +17,29 @@ limitations under the License. #include +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace cpu { namespace runtime { -XfeedManager* GetXfeedManager() { - static XfeedManager* manager = new XfeedManager; - return manager; +XfeedManager* GetXfeedManager(int device_ordinal) { + static tensorflow::gtl::FlatMap* managers = + new tensorflow::gtl::FlatMap(); + static absl::Mutex* mutex = new absl::Mutex(); + + absl::MutexLock lock(mutex); + auto it = managers->find(device_ordinal); + if (it == managers->end()) { + it = managers->emplace(device_ordinal, new XfeedManager()).first; + } + return it->second; } extern const char* const kEigenMatMulF16SymbolName = @@ -73,6 +84,30 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; +extern const char* const kKeyValueSortPREDSymbolName = + "__xla_cpu_runtime_KeyValueSortPRED"; +extern const char* const kKeyValueSortS8SymbolName = + "__xla_cpu_runtime_KeyValueSortS8"; +extern const char* const kKeyValueSortU8SymbolName = + "__xla_cpu_runtime_KeyValueSortU8"; +extern const char* const kKeyValueSortS16SymbolName = + "__xla_cpu_runtime_KeyValueSortS16"; +extern const char* const kKeyValueSortU16SymbolName = + "__xla_cpu_runtime_KeyValueSortU16"; +extern const char* const kKeyValueSortF16SymbolName = + "__xla_cpu_runtime_KeyValueSortF16"; +extern const char* const kKeyValueSortS32SymbolName = + "__xla_cpu_runtime_KeyValueSortS32"; +extern const char* const kKeyValueSortU32SymbolName = + "__xla_cpu_runtime_KeyValueSortU32"; +extern const char* const kKeyValueSortF32SymbolName = + "__xla_cpu_runtime_KeyValueSortF32"; +extern const char* const kKeyValueSortS64SymbolName = + "__xla_cpu_runtime_KeyValueSortS64"; +extern const char* const kKeyValueSortU64SymbolName = + "__xla_cpu_runtime_KeyValueSortU64"; +extern const char* const kKeyValueSortF64SymbolName = + "__xla_cpu_runtime_KeyValueSortF64"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime @@ -93,14 +128,18 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { } // namespace TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* -__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, - const void* shape, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "AcquireInfeedBufferForDequeue: " - << ShapeString(shape, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_AcquireInfeedBufferForDequeue( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "AcquireInfeedBufferForDequeue: " + << ShapeString(shape, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); // Wait until there's a buffer to dequeue. xla::cpu::runtime::XfeedBuffer* buffer = xfeed->infeed()->BlockingDequeueBuffer(); @@ -113,15 +152,18 @@ __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, - void* buffer_ptr, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "ReleaseInfeedBufferAfterDeque: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "ReleaseInfeedBufferAfterDeque: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); xla::StatusOr shape = xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, @@ -129,14 +171,18 @@ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* -__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "AcquireOutfeedBufferForPopulation: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_AcquireOutfeedBufferForPopulation( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "AcquireOutfeedBufferForPopulation: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); // Wait until there's a buffer to dequeue. xla::cpu::runtime::XfeedBuffer* buffer = xfeed->outfeed()->BlockingDequeueBuffer(); @@ -149,15 +195,18 @@ __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length, - void* buffer_ptr, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); xla::StatusOr shape = xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index aa0e96712302e806a389c6ad05a2c1b6634ef901..b2e760a224ad8eaa61dae57b0f9cece04a7e54ae 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -26,6 +26,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ +#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/types.h" @@ -63,13 +64,26 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; +extern const char* const kKeyValueSortPREDSymbolName; +extern const char* const kKeyValueSortS8SymbolName; +extern const char* const kKeyValueSortU8SymbolName; +extern const char* const kKeyValueSortS16SymbolName; +extern const char* const kKeyValueSortU16SymbolName; +extern const char* const kKeyValueSortF16SymbolName; +extern const char* const kKeyValueSortS32SymbolName; +extern const char* const kKeyValueSortU32SymbolName; +extern const char* const kKeyValueSortF32SymbolName; +extern const char* const kKeyValueSortS64SymbolName; +extern const char* const kKeyValueSortU64SymbolName; +extern const char* const kKeyValueSortF64SymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. extern const char* const kXlaCpuRuntimeSymbolNamePrefix; -// Returns the infeed manager used by the CPU runtime. -XfeedManager* GetXfeedManager(); +// Returns the infeed manager used by the CPU runtime for the CPU device +// `device_ordinal`. Note the device ordinal does not name a CPU +XfeedManager* GetXfeedManager(int device_ordinal); } // namespace runtime } // namespace cpu @@ -77,6 +91,18 @@ XfeedManager* GetXfeedManager(); extern "C" { +// Some things common to all of the runtime entry points below: +// +// * The shape pointer and shape_length reflect values that can be deserialized +// via llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass +// reified type information from the generated program to the runtime, which +// helps check the type safety and contract for the emitted-code/runtime +// communication. +// +// * run_options is used to look up the device ordinal for the stream executor +// we're executing under. If it is null the device ordinal is assumed to be +// 0 (this behavior helps in writing tests). + // Note: in the runtime entry points below, the shape pointer and shape_length // reflect values that can be deserialized via // llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified @@ -89,7 +115,8 @@ extern "C" { // the length would be more exact, but the length check is chosen as a // tradeoff between error checking and speed/simplicity. extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - xla::int32 buffer_length, const void* shape, xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape, xla::int32 shape_length); // Relinquishes the next infeed buffer that was returned by // __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call @@ -104,13 +131,14 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( // implemented we will add support for multiple outstanding buffers // that can be returned out of order. extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length); // Blocks until the next outfeed buffer is available to be populated, then // returns it. extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape_ptr, xla::int32 shape_length); // Relinquishes the outfeed buffer after it has been populated. // buffer_ptr must have been previously returned by @@ -122,8 +150,8 @@ extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( // acquired, i.e., there may only be one outstanding outfeed buffer in // use by the runtime. extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length); } // extern "C" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 2ac950e6d93ade315808f2ca1d0bdd7bc85f53b9..1ae3aa57111e3a3b7ac18b4907c5c282edf89b7e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -46,7 +46,7 @@ std::unique_ptr> MaybeTransposeArray2D(const Array2D& array, if (transpose) { std::swap(output_width, output_height); } - auto output = MakeUnique>(output_height, output_width); + auto output = absl::make_unique>(output_height, output_width); for (int y = 0; y < array.height(); y++) { for (int x = 0; x < array.width(); x++) { if (transpose) { @@ -93,7 +93,7 @@ std::unique_ptr> EigenMatrixMultiply(const Array2D& a, // Since we're going to transpose c before returning it. Swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. - auto c_transpose = MakeUnique>(n, m); + auto c_transpose = absl::make_unique>(n, m); if (single_threaded) { __xla_cpu_runtime_EigenSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), @@ -142,10 +142,10 @@ class EigenMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("EigenMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; @@ -178,10 +178,10 @@ class MKLMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("MKLMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; @@ -204,7 +204,7 @@ std::unique_ptr> MKLMatrixMultiply(const Array2D& a, // Since we're going to transpose c before returning it, swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. - auto c_transpose = MakeUnique>(n, m); + auto c_transpose = absl::make_unique>(n, m); if (single_threaded) { __xla_cpu_runtime_MKLSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 156166bf2b1ea6d3821da8f67ea2b2eca6825ca6..1cc2844470376ceb61601f6d1361def84eac5b45 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { @@ -103,7 +105,7 @@ Status CpuTransferManager::TransferLiteralToInfeed( if (ShapeUtil::IsNestedTuple(shape)) { return Unimplemented( "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + ShapeUtil::HumanString(literal.shape())); } // For a tuple, we transfer each of its elements to the device and @@ -127,7 +129,8 @@ Status CpuTransferManager::TransferLiteralToInfeed( buffers.push_back(buffer); } - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers); cleanup.release(); @@ -140,7 +143,8 @@ Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, size, source)); - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer}); return Status::OK(); @@ -151,11 +155,11 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Infeed shape must have positive size; got %lld", + return InvalidArgument("Infeed shape must have positive size; got %d", size); } @@ -173,26 +177,24 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, Status CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { if (!ShapeUtil::IsTuple(literal_shape)) { int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from // literal_shape.dimensions() to the array slice on 2017-07-10. - tensorflow::gtl::ArraySlice dimensions( + absl::Span dimensions( tensorflow::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); - *literal = std::move(*LiteralUtil::CreateFromDimensions( - literal_shape.element_type(), dimensions)); - TF_ASSIGN_OR_RETURN(Shape received_shape, - TransferArrayBufferFromOutfeed( - executor, literal->untyped_data(), size)); - TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) + TF_ASSIGN_OR_RETURN( + Shape received_shape, + TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size)); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape())) << "Shape received from outfeed " << ShapeUtil::HumanString(received_shape) << " did not match the shape that was requested for outfeed: " << ShapeUtil::HumanString(literal_shape); TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); - *literal->mutable_shape_do_not_use() = received_shape; + *literal.mutable_shape_do_not_use() = received_shape; return Status::OK(); } @@ -201,22 +203,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( "Nested tuple outfeeds are not yet implemented on CPU."); } - std::vector> elements; std::vector> buffer_data; for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { const Shape& tuple_element_shape = ShapeUtil::GetTupleElementShape(literal_shape, i); - // Note: OSS build didn't like implicit conversion from - // literal_shape.dimensions() to the array slice on 2017-07-10. - tensorflow::gtl::ArraySlice dimensions( - tensorflow::bit_cast( - tuple_element_shape.dimensions().data()), - tuple_element_shape.dimensions().size()); - auto empty = LiteralUtil::CreateFromDimensions( - tuple_element_shape.element_type(), dimensions); int64 size = GetByteSizeRequirement(tuple_element_shape); - buffer_data.push_back({empty->untyped_data(), size}); - elements.push_back(std::move(empty)); + buffer_data.push_back({literal.untyped_data({i}), size}); } TF_ASSIGN_OR_RETURN(Shape received_shape, @@ -230,17 +222,13 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( TF_RET_CHECK(GetByteSizeRequirement(literal_shape) == GetByteSizeRequirement(received_shape)); - for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { - *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i); - } - *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements))); - TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); + TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape)); return Status::OK(); } StatusOr CpuTransferManager::TransferTupleBuffersFromOutfeed( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data) { + absl::Span> buffer_data) { return TransferBuffersFromOutfeedInternal(executor, buffer_data, /*is_tuple=*/true); } @@ -253,18 +241,17 @@ StatusOr CpuTransferManager::TransferArrayBufferFromOutfeed( StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data, - bool is_tuple) { + absl::Span> buffer_data, bool is_tuple) { std::vector> buffers; for (auto b : buffer_data) { int64 size = b.second; if (size > std::numeric_limits::max()) { - return InvalidArgument("Outfeed shape is too large: needs %lld bytes", + return InvalidArgument("Outfeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Outfeed shape must have positive size; got %lld", + return InvalidArgument("Outfeed shape must have positive size; got %d", size); } @@ -272,7 +259,7 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( VLOG(2) << "Enqueueing outfeed buffer (for the device to populate) of length " << size_32 << "B"; - buffers.emplace_back(MakeUnique(b.first, size_32)); + buffers.emplace_back(absl::make_unique(b.first, size_32)); } std::vector buffer_pointers; @@ -281,7 +268,8 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( buffer_pointers.push_back(b.get()); } - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers); VLOG(2) << "Waiting for buffer to be notified as populated."; std::vector outfed_shapes; @@ -299,7 +287,7 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( } // namespace xla static std::unique_ptr CreateCpuTransferManager() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 593575c0fdaddc71cd6bd844fd179096a9fb0fdc..361d4b9c8422fff6afe53e56e0bb10a484c9becc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -13,17 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_ #include +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -41,7 +42,7 @@ class CpuTransferManager : public GenericTransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, @@ -55,7 +56,7 @@ class CpuTransferManager : public GenericTransferManager { // Helper that transfers a tuple of element buffers from the device's outfeed. StatusOr TransferTupleBuffersFromOutfeed( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data); + absl::Span> buffer_data); // Helper that transfers an array buffer from the device's outfeed. StatusOr TransferArrayBufferFromOutfeed(se::StreamExecutor* executor, @@ -67,12 +68,11 @@ class CpuTransferManager : public GenericTransferManager { // for the given buffers. StatusOr TransferBuffersFromOutfeedInternal( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data, - bool is_tuple); + absl::Span> buffer_data, bool is_tuple); TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager); }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index e4c674e227ffc6725ca929f720b9aa7cf7c4c032..3ae64142cd7e32d3aa8d50870efaf94698c06440 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -21,13 +21,13 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "llvm/MC/MCInst.h" #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -151,7 +151,7 @@ StatusOr Disassembler::DisassembleObjectFile( size = 1; } - ostream << tensorflow::strings::Printf("0x%08lx", index) << " "; + ostream << absl::StrFormat("0x%08lx", index) << " "; if (decode_status == llvm::MCDisassembler::Success) { // For branches, try to determine the actual address and emit it as an @@ -163,7 +163,7 @@ StatusOr Disassembler::DisassembleObjectFile( uint64_t target; if (inst_analysis_->evaluateBranch( instruction, section_address + index, size, target)) { - annotation = tensorflow::strings::Printf("[0x%08lx]", target); + annotation = absl::StrFormat("[0x%08lx]", target); } } inst_printer_->printInst(&instruction, ostream, annotation.c_str(), diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 645888de783e4025cffd6fa4835e60b84bbd7d99..99fa707c959854e50c6d954fe92b87e93e267dc6 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -79,7 +80,7 @@ class MemoryTile { // `minor_dim_offset`}. // // Note: `major_dim_offset` is a parameter to the constructor. - void StoreTile(tensorflow::gtl::ArraySlice tile, + void StoreTile(absl::Span tile, llvm::Value* minor_dim_offset) const { CHECK_EQ(tile.size(), pointers_.size()); for (int64 i = 0; i < pointers_.size(); i++) { @@ -146,9 +147,9 @@ class GemvConfig { bool has_addend() const { return has_addend_; } string GetCacheKey() const { - return tensorflow::strings::StrCat( - name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", - tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : ""); + return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", + tile_rows(), "_", tile_cols(), "_", m(), "_", k(), + has_addend() ? "_with_addend" : ""); } protected: @@ -621,19 +622,19 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } // This class implements a tiled matrix multiplication algorithm, intended for -// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto, -// Kazushige, and Robert Van De Geijn. "High-performance implementation of the -// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008): -// 4). +// multiplying small matrices that don't need cache tiling. +// +// In the future this can be used as the innermost GEBP loop in a GEMM kernel as +// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of +// high-performance matrix multiplication." ACM Transactions on Mathematical +// Software (TOMS) 34.3 (2008): 12.". // // This only supports canonical dot operations (i.e. where the lhs contraction // dimension is 1 and the rhs contraction dimension is 0) over row major // matrices. -class MatrixMatrixBlockPanelEmitter { +class TiledSmallGemmEmitter { public: - // Describe the dimensions of the GEBP kernel. These will usually not be the - // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP - // kernels with smaller dimensions. + // Describe the dimensions of the kernel. class Dimensions { public: explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} @@ -642,9 +643,7 @@ class MatrixMatrixBlockPanelEmitter { int64 k() const { return k_; } int64 n() const { return n_; } - string ToString() const { - return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); - } + string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } private: const int64 m_; @@ -652,9 +651,9 @@ class MatrixMatrixBlockPanelEmitter { const int64 n_; }; - // Represents the configuration of the GEBP emitter. The LLVM IR emitted by - // the emitter, modulo the LLVM values holding the input and output buffers, - // must be a function of the instance of `Config` passed to it. + // Represents the configuration of the emitter. The LLVM IR emitted by the + // emitter, modulo the LLVM values holding the input and output buffers, must + // be a function of the instance of `Config` passed to it. // // `dims` holds the matrix multiplication dimensions. // @@ -687,10 +686,10 @@ class MatrixMatrixBlockPanelEmitter { tile_size_k_(tile_size_k) {} string GetCacheKey() const { - return tensorflow::strings::StrCat( - "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), - "_", max_vectorization_width(), "_", min_vectorization_width(), "_", - tile_size_m(), "_", tile_size_k()); + return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", + dims().ToString(), "_", max_vectorization_width(), + "_", min_vectorization_width(), "_", tile_size_m(), + "_", tile_size_k()); } PrimitiveType scalar_type() const { return scalar_type_; } @@ -712,11 +711,11 @@ class MatrixMatrixBlockPanelEmitter { int64 tile_size_k_; }; - // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies // `lhs` with `rhs` and stores the result in `result`. - explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* b) + explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b) : lhs_(lhs), rhs_(rhs), result_(result), @@ -780,9 +779,9 @@ class MatrixMatrixBlockPanelEmitter { KernelSupportLibrary ksl_; }; -void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); } +void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } -void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { +void TiledSmallGemmEmitter::HandleResiduesOnN() { // We can only iterate the `n` dimension for an extent that is divisible by // the vectorization width. So we emit an outer loop that first processes the // largest extent in `n` that is divisible by max_vectorization_width, then @@ -799,7 +798,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { int64 n_end = dims().n() - (dims().n() % current_vectorization_width); if (n_start != n_end) { VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, - "gebp"); + "gemm"); HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); n_start = n_end; } @@ -813,7 +812,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { } if (n_start != dims().n()) { - VectorSupportLibrary vsl(scalar_type(), 1, b_, "gebp"); + VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); HandleResiduesOnK(&vsl, n_i, n_i_next); @@ -821,9 +820,9 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { } } -void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, - llvm::Value* n_start, - llvm::Value* n_end) { +void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { int64 k_start = 0; int64 k_end = dims().k() - (dims().k() % tile_size_k()); if (k_end != k_start) { @@ -838,7 +837,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, } } -void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( +void TiledSmallGemmEmitter::HandleResiduesOnM( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { const int64 m_end = dims().m() - dims().m() % tile_size_m(); @@ -921,7 +920,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( // +-------------------+-------------------+-------------------+--------- // | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... // +-------------------+-------------------+-------------------+--------- -void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( +void TiledSmallGemmEmitter::EmitTiledGemm( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { @@ -1001,12 +1000,22 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, return dot_emitter.Emit(); } -bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( +bool DotOpEmitter::EmitSmallGemmIfProfitable( const DotOpEmitter::MatMultDims& mat_mult_dims) { - if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) { + if (ShouldUseMultiThreadedEigen()) { return false; } + if (!EnableExperimentalLlvmIrGemm()) { + // TODO(sanjoy): We should make these numbers micro-arch specific. + bool small_gemm = mat_mult_dims.k <= 128 && + ((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) || + (mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32)); + if (!small_gemm) { + return false; + } + } + if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { return false; } @@ -1054,19 +1063,19 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = GetGemmTileSize(); - MatrixMatrixBlockPanelEmitter::Config config( + TiledSmallGemmEmitter::Config config( /*scalar_type=*/primitive_type, - MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, /*max_vectorization_width=*/max_target_vector_width, /*max_vector_count=*/tile_size_n_in_vector_width, /*min_vectorization_width=*/std::min(4, max_target_vector_width), /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); - VLOG(2) << "Emitting GEBP kernel in LLVM IR with config " + VLOG(2) << "Emitting GEMM kernel in LLVM IR with config " << config.GetCacheKey(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); @@ -1075,10 +1084,10 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs, rhs, target, [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { - MatrixMatrixBlockPanelEmitter gebp_emitter(config, /*lhs=*/lhs, - /*rhs=*/rhs, - /*result=*/target, b_); - gebp_emitter.Emit(); + TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, + /*rhs=*/rhs, + /*result=*/target, b_); + small_gemm_emitter.Emit(); }); return true; @@ -1136,7 +1145,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return EmitExperimentalGebpDotIfEnabled(mat_mult_dims); + return EmitSmallGemmIfProfitable(mat_mult_dims); } int64 tiling_factor = GetGemvTilingFactor(); @@ -1149,7 +1158,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); @@ -1458,7 +1467,7 @@ Status DotOpEmitter::EmitCallToRuntime() { break; default: return Unimplemented("Invalid type %s for dot operation", - PrimitiveType_Name(type).c_str()); + PrimitiveType_Name(type)); } llvm::Type* float_ptr_type = float_type->getPointerTo(); @@ -1610,7 +1619,7 @@ bool PotentiallyImplementedAsEigenDot( // For vector-matrix dot products, it is always profitable to make the Rhs // column major. -tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( +absl::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 && hlo.shape().dimensions(0) == 1) { diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 590032fbe907d7ca90bf69b7ccc3170b8efec72e..4c2041b556aa8bf8fe8fb8e0674c0f4f04f0acae 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -38,7 +38,7 @@ bool PotentiallyImplementedAsEigenDot( // Returns the index for an operand to `hlo` that should ideally be column // major. Returns nullopt if there is no such operand or if `hlo` is not a dot // or a fusion containing a dot. -tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( +absl::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo); // Returns true to indicate that we can generate a tiled LLVM IR implementation @@ -121,7 +121,7 @@ class DotOpEmitter { // of rank 2 as well). MatMultDims GetMatMultDims() const; - bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims); + bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims); // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector // registers. diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index c13d36776f94221598338dca4eadf024c0a892df..c8312d80bd5012e5bcb42a410db18a7fa77a2eb6 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -30,56 +30,16 @@ limitations under the License. namespace xla { namespace cpu { -StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - switch (op->opcode()) { - case HloOpcode::kTanh: { - PrimitiveType element_type = op->shape().element_type(); - bool cast_result_to_fp16 = false; - string function_name; - switch (element_type) { - case F16: - cast_result_to_fp16 = true; - operand_value = b_->CreateFPCast(operand_value, b_->getFloatTy()); - TF_FALLTHROUGH_INTENDED; - case F32: - function_name = "tanhf"; - break; - case F64: - function_name = "tanh"; - break; - default: - return Unimplemented("tanh"); - } - // Create a function declaration. - llvm::Function* function = - llvm::cast(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), operand_value->getType(), - operand_value->getType())); - function->setCallingConv(llvm::CallingConv::C); - function->setDoesNotThrow(); - function->setDoesNotAccessMemory(); - // Create an instruction to call the function. - llvm::Value* result = b_->CreateCall(function, operand_value); - if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); - } - return result; - } - default: - return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); - } -} - -StatusOr CpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { string function_name; bool cast_result_to_fp16 = false; switch (prim_type) { case F16: cast_result_to_fp16 = true; - lhs = b_->CreateFPCast(lhs, b_->getFloatTy()); - rhs = b_->CreateFPCast(rhs, b_->getFloatTy()); + lhs = FPCast(lhs, b_->getFloatTy()); + rhs = FPCast(rhs, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "atan2f"; @@ -99,16 +59,49 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. - llvm::Value* result = b_->CreateCall(function, {lhs, rhs}); + llvm::Value* result = Call(function, {lhs, rhs}); + if (cast_result_to_fp16) { + result = FPCast(result, b_->getHalfTy()); + } + return result; +} + +StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { + bool cast_result_to_fp16 = false; + string function_name; + switch (prim_type) { + case F16: + cast_result_to_fp16 = true; + value = FPCast(value, b_->getFloatTy()); + TF_FALLTHROUGH_INTENDED; + case F32: + function_name = "tanhf"; + break; + case F64: + function_name = "tanh"; + break; + default: + return Unimplemented("tanh"); + } + // Create a function declaration. + llvm::Function* function = llvm::cast( + module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name), + value->getType(), value->getType())); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create an instruction to call the function. + llvm::Value* result = Call(function, value); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { if (hlo->opcode() == HloOpcode::kMap) { return [this, hlo, &operand_to_generator]( const llvm_ir::IrArray::Index& index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 9598a886ab49fcecf5df7bd65f425fe485de3574..e3fba9306b72904803259047fafea245a8e183db 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -36,13 +36,13 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: - StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; + StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index ca645d3f1da18fb26378a10526c27a7d254896e2..c3e802078385d4724f0da26e8b6c16503e3662a1 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -27,6 +27,9 @@ limitations under the License. #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" @@ -64,11 +67,8 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -99,13 +99,18 @@ IrEmitter::IrEmitter( target_machine_features_(*target_machine_features) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() - .xla_enable_fast_math())); + .xla_cpu_enable_fast_math())); + Status s = GatherComputationsByAllocationType( + &hlo_module, &thread_local_computations_, &global_computations_); + absl::c_sort(thread_local_computations_); + absl::c_sort(global_computations_); + TF_CHECK_OK(s) << "Should have failed buffer assignment."; } StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order) { + const std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? " << (instruction_order != nullptr); @@ -158,11 +163,11 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; // Create and initialize new IrFunction. - compute_function_.reset( - new IrFunction(function_name, linkage, - options::OptimizeForSizeRequested(hlo_module_config_), - hlo_module_config_.debug_options().xla_enable_fast_math(), - module_, &b_, num_dynamic_loop_bounds_)); + compute_function_.reset(new IrFunction( + function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_, + &b_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -170,9 +175,9 @@ IrEmitter::~IrEmitter() {} Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); emitted_value_[bitcast] = - b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), - IrShapeType(bitcast->shape())->getPointerTo(), - AsStringRef(IrName(bitcast))); + BitCast(GetEmittedValueFor(bitcast->operand(0)), + IrShapeType(bitcast->shape())->getPointerTo(), + AsStringRef(IrName(bitcast))); return Status::OK(); } @@ -230,9 +235,8 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { // Use the elemental emitter for array shapes. return DefaultAction(copy); } - return Unimplemented( - "unsupported operand type %s for copy instruction", - PrimitiveType_Name(copy->shape().element_type()).c_str()); + return Unimplemented("unsupported operand type %s for copy instruction", + PrimitiveType_Name(copy->shape().element_type())); } // Calculate the alignment of a buffer allocated for a given primitive type. @@ -338,10 +342,10 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Write the tuple index table. TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, assignment_.GetUniqueSlice(infeed, {0})); - llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape); + llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice, assignment_.GetUniqueSlice(infeed, {1})); - llvm::Value* token_address = EmitTempBufferPointer( + llvm::Value* token_address = EmitBufferPointer( token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, module_); @@ -364,9 +368,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Only the outer tuple buffer's target address is obtained from // GetEmittedValueFor, to handle the case when Infeed is the root // instruction. Target addresses for internal elements can be obtained - // from EmitTempBufferPointer. + // from EmitBufferPointer. llvm::Value* tuple_element_address = - EmitTempBufferPointer(buffer, tuple_element_shape); + EmitBufferPointer(buffer, tuple_element_shape); TF_RETURN_IF_ERROR(EmitXfeedTransfer( XfeedKind::kInfeed, tuple_element_shape, tuple_element_address)); @@ -389,7 +393,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, int64 length = ByteSizeOf(shape); if (length <= 0 || length > std::numeric_limits::max()) { return InvalidArgument( - "xfeed (infeed or outfeed) buffer length %lld is outside the valid " + "xfeed (infeed or outfeed) buffer length %d is outside the valid " "size range", length); } @@ -400,13 +404,12 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value * shape_ptr, llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_)); - // The signature of the acquire infeed buffer function is: - // - // (void*)(int32 length); llvm::Type* int32_type = b_.getInt32Ty(); llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); llvm::FunctionType* acquire_type = llvm::FunctionType::get( - i8_ptr_type, {int32_type, i8_ptr_type, int32_type}, + i8_ptr_type, + {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, + /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type}, /*isVarArg=*/false); llvm::Function* acquire_func; @@ -419,11 +422,11 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, } acquire_func->setCallingConv(llvm::CallingConv::C); - // The signature of the release infeed buffer function is: - // - // (void)(int32 length, void* buffer); llvm::FunctionType* release_type = llvm::FunctionType::get( - b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, + b_.getVoidTy(), + {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, + /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type, + /*shape_length*/ int32_type}, /*isVarArg=*/false); llvm::Function* release_func; @@ -440,27 +443,33 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // of size exactly 'length_32', and the runtime is responsible for // check-failing the process if there is a mismatch, versus passing us back a // buffer that we might overrun. - llvm::Value* acquired_pointer = b_.CreateCall( - acquire_func, - {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); + llvm::Value* acquired_pointer = Call( + acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + shape_ptr, b_.getInt32(shape_length)}); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, - /*SrcAlign=*/1, length_32); + MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. - b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, - /*SrcAlign=*/1, length_32); + MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, + /*SrcAlign=*/1, length_32); } - b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer, - shape_ptr, b_.getInt32(shape_length)}); + Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + acquired_pointer, shape_ptr, b_.getInt32(shape_length)}); return Status::OK(); } Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { + // Outfeed produces no useful result, but it does return a token[] that can be + // threaded through to other side effecting operations to ensure ordering. In + // the IR emitter we treat this token as a normal u8[] and thus need to insert + // an entry for it in emitted_value_. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed)); + HloInstruction* operand = outfeed->operands()[0]; const Shape& operand_shape = operand->shape(); @@ -485,8 +494,150 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { } Status IrEmitter::HandleSort(HloInstruction* sort) { - // TODO(b/26783907): Implement sort on CPU. - return Unimplemented("Sort is not implemented on CPU."); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); + auto keys = sort->operand(0); + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + ShapeIndex keys_shape_index({}); + ShapeIndex values_shape_index({}); + if (values != nullptr) { + keys_shape_index = ShapeIndex({0}); + values_shape_index = ShapeIndex({1}); + } + auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); + auto keys_destination_address = + EmitBufferPointer(keys_destination, keys->shape()); + auto values_destination = GetAllocationSlice(*sort, values_shape_index); + llvm::Value* values_destination_address = nullptr; + + // The sort is implemented in-place, therefore we first copy the operand + // buffer to the output buffer if they are not the same. + if (keys_destination != GetAllocationSlice(*keys)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type()); + auto source_buffer = GetEmittedValueFor(keys); + int64 keys_size = ByteSizeOf(keys->shape()); + MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, keys_size); + } + if (values != nullptr) { + values_destination_address = + EmitBufferPointer(values_destination, values->shape()); + if (values_destination != GetAllocationSlice(*values)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type()); + auto source_buffer = GetEmittedValueFor(values); + int64 values_size = ByteSizeOf(values->shape()); + MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, values_size); + } + } + + // Normalize the shape and the dimension to sort. + Shape normalized_keys_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + keys->shape()); + int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical( + keys->shape().layout())[sort->dimensions(0)]; + + int64 sort_dimension_elements = + normalized_keys_shape.dimensions(physical_dimension_to_sort); + int64 higher_dimensions = 1; + for (int64 i = 0; i < physical_dimension_to_sort; ++i) { + higher_dimensions *= normalized_keys_shape.dimensions(i); + } + int64 lower_dimensions = 1; + for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1; + i > physical_dimension_to_sort; --i) { + lower_dimensions *= normalized_keys_shape.dimensions(i); + } + + PrimitiveType keys_type = keys->shape().element_type(); + const char* fn_name = nullptr; + llvm::Type* keys_native_type = nullptr; + switch (keys_type) { + case PRED: + fn_name = runtime::kKeyValueSortPREDSymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S8: + fn_name = runtime::kKeyValueSortS8SymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case U8: + fn_name = runtime::kKeyValueSortU8SymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S16: + fn_name = runtime::kKeyValueSortS16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case U16: + fn_name = runtime::kKeyValueSortU16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case F16: + fn_name = runtime::kKeyValueSortF16SymbolName; + keys_native_type = b_.getHalfTy()->getPointerTo(); + break; + case S32: + fn_name = runtime::kKeyValueSortS32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case U32: + fn_name = runtime::kKeyValueSortU32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case F32: + fn_name = runtime::kKeyValueSortF32SymbolName; + keys_native_type = b_.getFloatTy()->getPointerTo(); + break; + case S64: + fn_name = runtime::kKeyValueSortS64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case U64: + fn_name = runtime::kKeyValueSortU64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case F64: + fn_name = runtime::kKeyValueSortF64SymbolName; + keys_native_type = b_.getDoubleTy()->getPointerTo(); + break; + default: + return Unimplemented( + "Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); + } + + llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( + b_.getVoidTy(), + {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), + b_.getInt8PtrTy(), b_.getInt32Ty()}, + /*isVarArg=*/false); + auto* key_value_sort_func = llvm::cast( + module_->getOrInsertFunction(fn_name, key_value_sort_type)); + key_value_sort_func->setCallingConv(llvm::CallingConv::C); + key_value_sort_func->setDoesNotThrow(); + key_value_sort_func->setOnlyAccessesArgMemory(); + Call(key_value_sort_func, + {PointerCast(keys_destination_address, keys_native_type), + b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), + b_.getInt64(lower_dimensions), + values != nullptr + ? PointerCast(values_destination_address, b_.getInt8PtrTy()) + : llvm::Constant::getNullValue(b_.getInt8PtrTy()), + b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType( + values->shape().element_type()) + : 0)}); + + if (values != nullptr) { + llvm_ir::EmitTuple(GetIrArrayFor(sort), + {keys_destination_address, values_destination_address}, + &b_, module_); + } + return Status::OK(); } Status IrEmitter::HandleTuple(HloInstruction* tuple) { @@ -501,8 +652,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { llvm::Value* IrEmitter::EmitElementalMap( const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice elemental_operands, - tensorflow::StringPiece name) { + absl::Span elemental_operands, absl::string_view name) { return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } @@ -519,8 +669,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "reduce_window_accumulator_address", &b_, MinimumAlignmentForPrimitiveType(operand_element_type)); - b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))), - accumulator_address); + Store(Load(GetEmittedValueFor(reduce_window->operand(1))), + accumulator_address); llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); std::vector window_size; @@ -537,22 +687,21 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = - b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); + input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); // We need to check if 0 <= input_index[i] < bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to // input_index[i] < bound as an *unsigned* comparison, since a negative // value will wrap to a large positive value. - llvm::Value* index_condition = b_.CreateICmpULT( - input_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + llvm::Value* index_condition = + ICmpULT(input_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); if (in_bounds_condition == nullptr) { in_bounds_condition = index_condition; } else { - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } } CHECK(in_bounds_condition != nullptr); @@ -565,19 +714,19 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::IrArray input_array(GetIrArrayFor(operand)); llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce_window->to_apply(), - {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); - b_.CreateStore(result, accumulator_address); + *reduce_window->to_apply(), {Load(accumulator_address), input_value}, + "reducer_function"); + Store(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_address); + return Load(accumulator_address); } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{reduce_window->operand(0)}, - /*supported_types=*/{F32, BF16, S32})); + /*supported_types=*/{F32, BF16, S32, F16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(reduce_window->window())) { @@ -647,7 +796,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"), [this, init_value](const llvm_ir::IrArray::Index& target_index) { llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - return b_.CreateLoad(init_value_addr); + return Load(init_value_addr); })); // Create a loop to iterate over the source array to scatter to the output. @@ -667,7 +816,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_); @@ -685,15 +834,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size()); llvm::Value* in_bounds_condition = b_.getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( - source_index[i], b_.getInt64(window.dimensions(i).stride())); - operand_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( - operand_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + llvm::Value* strided_index = + NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride())); + operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); + llvm::Value* index_condition = + ICmpULT(operand_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -703,7 +851,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. @@ -712,38 +860,37 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { [&](const llvm_ir::IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. SetToFirstInsertPoint(if_initialized.true_block, &b_); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* operand_element = Load(operand_address); llvm::Value* result = EmitThreadLocalCall( *select_and_scatter->select(), - {b_.CreateLoad(selected_value_address), operand_element}, - "select_function"); + {Load(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -754,8 +901,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value = @@ -837,7 +984,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( lhs_llvm_type, "convolution_sum_address", &b_, MinimumAlignmentForPrimitiveType(lhs_element_type)); llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type); - b_.CreateStore(constant_zero, sum_address); + Store(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_); std::vector kernel_spatial(num_spatial_dims); @@ -846,7 +993,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( loops .AddLoop( 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)), - tensorflow::strings::StrCat("k", i)) + absl::StrCat("k", i)) ->GetIndVarValue(); } llvm::Value* input_feature = @@ -864,11 +1011,11 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::Value* kernel_index, const WindowDimension& window_dim) { llvm::Value* strided_index = - b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = b_.CreateNSWMul( - kernel_index, b_.getInt64(window_dim.window_dilation())); - return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index), - b_.getInt64(window_dim.padding_low())); + NSWMul(output_index, b_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = + NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation())); + return NSWSub(NSWAdd(strided_index, dilated_kernel_index), + b_.getInt64(window_dim.padding_low())); }; std::vector input_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -885,9 +1032,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( // Also need to check that the input coordinates are not in one of the // holes created by base dilation. const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) { - llvm::Value* remainder = - b_.CreateSRem(input_index, b_.getInt64(base_dilation)); - return b_.CreateICmpEQ(remainder, b_.getInt64(0)); + llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation)); + return ICmpEQ(remainder, b_.getInt64(0)); }; llvm::Value* in_bounds_condition = b_.getInt1(true); @@ -895,17 +1041,17 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound( lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), window.dimensions(i).base_dilation())); - llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound); + llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound); llvm::Value* dim_not_in_hole = not_in_hole(input_spatial[i], window.dimensions(i).base_dilation()); - llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok); + llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole); + in_bounds_condition = And(in_bounds_condition, dim_ok); } // Now we need to map the dilated base coordinates back to the actual // data indices on the lhs. const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) { - return b_.CreateSDiv(input_index, b_.getInt64(base_dilation)); + return SDiv(input_index, b_.getInt64(base_dilation)); }; for (int i = 0; i < num_spatial_dims; ++i) { input_spatial[i] = @@ -930,8 +1076,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() - ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1), - kernel_spatial[i]) + ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) : kernel_spatial[i]; } @@ -940,13 +1086,13 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); llvm::Value* product = - b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_), - kernel_array.EmitReadArrayElement(kernel_index, &b_)); - llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product); - b_.CreateStore(sum, sum_address); + FMul(input_array.EmitReadArrayElement(input_index, &b_), + kernel_array.EmitReadArrayElement(kernel_index, &b_)); + llvm::Value* sum = FAdd(Load(sum_address), product); + Store(sum, sum_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(sum_address); + return Load(sum_address); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -1072,34 +1218,32 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); - b_.CreateCall( - conv_func, - { - GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type), - b_.CreateBitCast(lhs_address, ir_ptr_type), - b_.CreateBitCast(rhs_address, ir_ptr_type), - b_.getInt64(input_batch), - b_.getInt64(input_rows), - b_.getInt64(input_cols), - b_.getInt64(input_channels), - b_.getInt64(kernel_rows), - b_.getInt64(kernel_cols), - b_.getInt64(kernel_channels), - b_.getInt64(kernel_filters), - b_.getInt64(output_rows), - b_.getInt64(output_cols), - b_.getInt64(row_stride), - b_.getInt64(col_stride), - b_.getInt64(padding_top), - b_.getInt64(padding_bottom), - b_.getInt64(padding_left), - b_.getInt64(padding_right), - b_.getInt64(lhs_row_dilation), - b_.getInt64(lhs_col_dilation), - b_.getInt64(rhs_row_dilation), - b_.getInt64(rhs_col_dilation), - }); + Call(conv_func, { + GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(convolution), ir_ptr_type), + BitCast(lhs_address, ir_ptr_type), + BitCast(rhs_address, ir_ptr_type), + b_.getInt64(input_batch), + b_.getInt64(input_rows), + b_.getInt64(input_cols), + b_.getInt64(input_channels), + b_.getInt64(kernel_rows), + b_.getInt64(kernel_cols), + b_.getInt64(kernel_channels), + b_.getInt64(kernel_filters), + b_.getInt64(output_rows), + b_.getInt64(output_cols), + b_.getInt64(row_stride), + b_.getInt64(col_stride), + b_.getInt64(padding_top), + b_.getInt64(padding_bottom), + b_.getInt64(padding_left), + b_.getInt64(padding_right), + b_.getInt64(lhs_row_dilation), + b_.getInt64(lhs_col_dilation), + b_.getInt64(rhs_row_dilation), + b_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -1159,15 +1303,14 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); const int fft_rank = fft_length.size(); - b_.CreateCall( - fft_func, - {GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), - b_.CreateBitCast(operand_address, int8_ptr_type), - b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank), - b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), - b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), - b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + Call(fft_func, + {GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(fft), int8_ptr_type), + BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()), + b_.getInt32(fft_rank), b_.getInt64(input_batch), + b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); return Status::OK(); } @@ -1203,11 +1346,11 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { const Shape& operand_shape = crs->operand(i)->shape(); CHECK(ShapeUtil::IsArray(operand_shape)) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. - b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, - /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); + MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); } llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_); return Status::OK(); @@ -1457,7 +1600,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, gtl::ArraySlice dimensions, + HloInstruction* arg, absl::Span dimensions, unsigned element_alignment) { ShardedVector accumulator; accumulator.reserve(accumulator_type.size()); @@ -1466,19 +1609,19 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( accumulator_shard_type, "accumulator", &b_, 0)); } - llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value)); + llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value)); for (llvm::Value* accumulator_shard : accumulator) { llvm::Value* initial_value; auto shard_type = accumulator_shard->getType()->getPointerElementType(); if (auto vector_type = llvm::dyn_cast(shard_type)) { initial_value = - b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa); + VectorSplat(vector_type->getNumElements(), init_value_ssa); } else { initial_value = init_value_ssa; } - b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment); + AlignedStore(initial_value, accumulator_shard, element_alignment); } llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), @@ -1500,24 +1643,24 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( } CHECK(output_index.end() == it); - llvm::Value* input_address = b_.CreateBitCast( + llvm::Value* input_address = BitCast( arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); for (int i = 0; i < accumulator.size(); i++) { auto input_address_typed = - b_.CreateBitCast(input_address, accumulator[i]->getType()); + BitCast(input_address, accumulator[i]->getType()); auto current_accumulator_value = - b_.CreateAlignedLoad(accumulator[i], element_alignment); - auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment); + AlignedLoad(accumulator[i], element_alignment); + auto addend = AlignedLoad(input_address_typed, element_alignment); arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); auto reduced_result = reduction_generator(&b_, current_accumulator_value, addend); - b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment); + AlignedStore(reduced_result, accumulator[i], element_alignment); if (i != (accumulator.size() - 1)) { - input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(), - input_address_typed, 1); + input_address = ConstInBoundsGEP1_32(reduced_result->getType(), + input_address_typed, 1); } } @@ -1526,8 +1669,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( ShardedVector result_ssa; result_ssa.reserve(accumulator.size()); for (auto accumulator_shard : accumulator) { - result_ssa.push_back( - b_.CreateAlignedLoad(accumulator_shard, element_alignment)); + result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment)); } return result_ssa; } @@ -1536,25 +1678,25 @@ void IrEmitter::EmitShardedVectorStore( llvm::Value* store_address, const std::vector& value_to_store, const int alignment, const llvm_ir::IrArray& containing_array) { for (int i = 0; i < value_to_store.size(); i++) { - auto store_address_typed = b_.CreateBitCast( - store_address, - llvm::PointerType::getUnqual(value_to_store[i]->getType())); + auto store_address_typed = + BitCast(store_address, + llvm::PointerType::getUnqual(value_to_store[i]->getType())); - auto store_instruction = b_.CreateAlignedStore( - value_to_store[i], store_address_typed, alignment); + auto store_instruction = + AlignedStore(value_to_store[i], store_address_typed, alignment); containing_array.AnnotateLoadStoreInstructionWithMetadata( store_instruction); if (i != (value_to_store.size() - 1)) { - store_address = b_.CreateConstInBoundsGEP1_32( - value_to_store[i]->getType(), store_address_typed, 1); + store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(), + store_address_typed, 1); } } } StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - gtl::ArraySlice dimensions, HloComputation* function, + absl::Span dimensions, HloComputation* function, string* failure_reason) { if (!ReductionPreservesLayout(*reduce)) { return false; @@ -1620,9 +1762,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); int64 start_index = 0; int64 end_index = reduce->shape().dimensions(dimension); - std::unique_ptr loop = - loop_nest.AddLoop(start_index, end_index, - tensorflow::strings::Printf("dim.%lld", dimension)); + std::unique_ptr loop = loop_nest.AddLoop( + start_index, end_index, absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } @@ -1641,9 +1782,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 start_index = 0; int64 end_index = (innermost_dimension_size / vectorization_factor) * vectorization_factor; - std::unique_ptr loop = loop_nest.AddLoop( - start_index, end_index, vectorization_factor, - tensorflow::strings::Printf("dim.%lld", innermost_dimension)); + std::unique_ptr loop = + loop_nest.AddLoop(start_index, end_index, vectorization_factor, + absl::StrFormat("dim.%d", innermost_dimension)); array_index[innermost_dimension] = loop->GetIndVarValue(); SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); @@ -1705,7 +1846,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) { const HloInstruction* arg = reduce->mutable_operand(0); const HloInstruction* init_value = reduce->mutable_operand(1); - gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); @@ -1713,8 +1854,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - llvm::Value* load_init_value = b_.CreateLoad(init_value_addr); - b_.CreateStore(load_init_value, accumulator_addr); + llvm::Value* load_init_value = Load(init_value_addr); + Store(load_init_value, accumulator_addr); // The enclosing loops go over all the target elements. Now we have to compute // the actual target element. For this, we build a new loop nest to iterate @@ -1747,18 +1888,22 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( // Apply the reduction function to the loaded value. llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, + *reduce->to_apply(), {Load(accumulator_addr), input_element}, "reduce_function"); - b_.CreateStore(result, accumulator_addr); + Store(result, accumulator_addr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); } Status IrEmitter::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Support variadic reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Variadic reduce is not supported on CPU"); + } auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (!options::VectorizedReduceDisabled(hlo_module_config_)) { string vectorization_failure_reason; @@ -1986,7 +2131,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { [this, pad](const llvm_ir::IrArray::Index& target_index) { const HloInstruction* padding_value = pad->operand(1); llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); - return b_.CreateLoad(padding_value_addr); + return Load(padding_value_addr); })); // Create a loop to iterate over the operand elements and update the output @@ -2008,10 +2153,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { const PaddingConfig& padding_config = pad->padding_config(); llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { - llvm::Value* offset = b_.CreateMul( - operand_index[i], - b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); - llvm::Value* index = b_.CreateAdd( + llvm::Value* offset = + Mul(operand_index[i], + b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); + llvm::Value* index = Add( offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low())); output_index.push_back(index); } @@ -2098,7 +2243,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), /*profile_counters_arg=*/GetProfileCountersArgument()); HloInstruction* root = computation->root_instruction(); @@ -2113,8 +2258,8 @@ Status IrEmitter::HandleCall(HloInstruction* call) { } Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { - gtl::ArraySlice operands(custom_call->operands()); - tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); + absl::Span operands(custom_call->operands()); + absl::string_view custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -2122,10 +2267,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { for (size_t i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = - b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); + PointerCast(GetEmittedValueFor(operand), i8_ptr_type); llvm::Value* slot_in_operands_alloca = - b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)}); - b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); + InBoundsGEP(operands_alloca, {b_.getInt64(i)}); + Store(operand_as_i8ptr, slot_in_operands_alloca); } auto* custom_call_ir_function = llvm::cast(module_->getOrInsertFunction( @@ -2137,9 +2282,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); auto* output_address_arg = - b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); + PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); - b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); + Call(custom_call_ir_function, {output_address_arg, operands_alloca}); return Status::OK(); } @@ -2166,8 +2311,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { return InternalError( "instruction %s %s does not share slice with " "instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), + slice_b.ToString()); } return Status::OK(); }; @@ -2198,15 +2343,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), compute_function_->function()); - b_.CreateBr(header_bb); + Br(header_bb); b_.SetInsertPoint(header_bb); // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); - llvm::Value* while_predicate = b_.CreateICmpNE( - b_.CreateLoad( - GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), + llvm::Value* while_predicate = ICmpNE( + Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2215,7 +2359,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); - b_.CreateCondBr(while_predicate, body_bb, exit_bb); + CondBr(while_predicate, body_bb, exit_bb); // Calls the body function from the body block. b_.SetInsertPoint(body_bb); @@ -2224,7 +2368,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); // Finishes with a branch back to the header. - b_.CreateBr(header_bb); + Br(header_bb); // Adds the exit block to the function and sets the insert point there. compute_function_->function()->getBasicBlockList().push_back(exit_bb); @@ -2234,7 +2378,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { } StatusOr IrEmitter::EmitFastConcatenate( - HloInstruction* concatenate, gtl::ArraySlice operands, + HloInstruction* concatenate, absl::Span operands, string* failure_reason) { if (ShouldEmitParallelLoopFor(*concatenate)) { *failure_reason = @@ -2271,7 +2415,6 @@ StatusOr IrEmitter::EmitFastConcatenate( output_min2maj.end()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); - llvm::Type* i8_type = b_.getInt8Ty(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); @@ -2294,9 +2437,9 @@ StatusOr IrEmitter::EmitFastConcatenate( // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. llvm::Value* target_region_begin = - b_.CreateBitCast(target_array.EmitArrayElementAddress( - outer_dims_index, &b_, "target_region"), - i8_ptr_type); + BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_, + "target_region"), + i8_ptr_type); int64 byte_offset_into_target_region = 0; int64 inner_dims_product = @@ -2310,13 +2453,12 @@ StatusOr IrEmitter::EmitFastConcatenate( for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); llvm_ir::IrArray source_array = GetIrArrayFor(operand); - llvm::Value* copy_source_address = b_.CreateBitCast( + llvm::Value* copy_source_address = BitCast( source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"), i8_ptr_type); llvm::Value* copy_target_address = - b_.CreateGEP(i8_type, target_region_begin, - b_.getInt64(byte_offset_into_target_region)); + GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region)); EmitTransferElements( copy_target_address, copy_source_address, @@ -2348,15 +2490,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); if (element_count == 1) { - auto* load_instruction = b_.CreateAlignedLoad( - b_.CreateBitCast(source, primitive_ptr_type), element_alignment); + auto* load_instruction = + AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); - auto* store_instruction = b_.CreateAlignedStore( - load_instruction, b_.CreateBitCast(target, primitive_ptr_type), - element_alignment); + auto* store_instruction = + AlignedStore(load_instruction, BitCast(target, primitive_ptr_type), + element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { - auto* memcpy_instruction = b_.CreateMemCpy( + auto* memcpy_instruction = MemCpy( target, /*DstAlign=*/element_alignment, source, /*SrcAlign=*/element_alignment, element_count * primitive_type_size); @@ -2372,7 +2514,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { - gtl::ArraySlice operands(concatenate->operands()); + absl::Span operands(concatenate->operands()); string failure_reason; TF_ASSIGN_OR_RETURN( bool successful, @@ -2418,9 +2560,9 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { // cond_result = true_computation(true_operand) // else // cond_result = false_computation(false_operand) - llvm::LoadInst* pred_value = b_.CreateLoad( - GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = b_.CreateICmpNE( + llvm::LoadInst* pred_value = + Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ICmpNE( pred_value, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); @@ -2446,11 +2588,6 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { return Status::OK(); } -Status IrEmitter::HandleIota(HloInstruction* iota) { - // TODO(b/64798317): implement iota on CPU. - return Unimplemented("Iota is not implemented on CPU."); -} - Status IrEmitter::HandleRng(HloInstruction* rng) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : rng->operands()) { @@ -2507,8 +2644,8 @@ llvm::Value* IrEmitter::GetProfileCounterCommon( int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); - return b_.CreateGEP(GetProfileCountersArgument(), - b_.getInt64(prof_counter_idx), AsStringRef(counter_name)); + return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx), + AsStringRef(counter_name)); } void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b, @@ -2626,15 +2763,15 @@ llvm::Value* IrEmitter::GetProfileCountersArgument() { return compute_function_->profile_counters_arg(); } -llvm::Value* IrEmitter::GetTempBuffersArgument() { - return compute_function_->temp_buffers_arg(); +llvm::Value* IrEmitter::GetBufferTableArgument() { + return compute_function_->buffer_table_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address = [&]() -> llvm::Value* { @@ -2662,8 +2799,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = - b_.CreateLoad(param_address_offset); + llvm::LoadInst* param_address_untyped = Load(param_address_offset); if (!ShapeUtil::IsOpaque(target_shape)) { AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); @@ -2683,25 +2819,23 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( auto buf_it = thread_local_buffers_.find(key); if (buf_it == thread_local_buffers_.end()) { llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( - IrShapeType(shape), - tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, - MinimumAlignmentForShape(target_shape)); + IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()), + &b_, MinimumAlignmentForShape(target_shape)); auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); CHECK(it_inserted_pair.second); buf_it = it_inserted_pair.first; } return buf_it->second; }(); - return b_.CreateBitCast(tempbuf_address, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( +llvm::Value* IrEmitter::EmitGlobalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( - GetTempBuffersArgument(), slice.index(), &b_); - llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); + GetBufferTableArgument(), slice.index(), &b_); + llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { tempbuf_address_base->setMetadata( @@ -2715,20 +2849,20 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( if (slice.offset() > 0) { // Adjust the address to account for the slice offset. tempbuf_address_untyped = - b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); + InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } - return b_.CreateBitCast(tempbuf_address_untyped, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address_untyped, + IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape) { +llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape) { if (slice.allocation()->is_thread_local()) { - return EmitThreadLocalTempBufferPointer(slice, target_shape); + return EmitThreadLocalBufferPointer(slice, target_shape); } else if (slice.allocation()->is_constant()) { return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); } else { - return EmitGlobalTempBufferPointer(slice, target_shape); + return EmitGlobalBufferPointer(slice, target_shape); } } @@ -2736,7 +2870,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { const Shape& target_shape = op->shape(); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueTopLevelSlice(op)); - llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); + llvm::Value* addr = EmitBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2749,7 +2883,7 @@ Status IrEmitter::EmitTargetElementLoop( } Status IrEmitter::EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); @@ -2765,8 +2899,7 @@ Status IrEmitter::EmitTargetElementLoop( TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, assignment_.GetUniqueSlice(target_op, {i})); const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i}); - llvm::Value* op_target_address = - EmitTempBufferPointer(slice, element_shape); + llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape); output_arrays.push_back( llvm_ir::IrArray(op_target_address, element_shape)); } @@ -2804,15 +2937,15 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, - /*SrcAlign=*/1, source_size); + MemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } Status IrEmitter::ElementTypesSameAndSupported( const HloInstruction& instruction, - gtl::ArraySlice operands, - gtl::ArraySlice supported_types) { + absl::Span operands, + absl::Span supported_types) { for (auto operand : operands) { TF_RET_CHECK( ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())); @@ -2823,8 +2956,8 @@ Status IrEmitter::ElementTypesSameAndSupported( if (std::find(supported_types.begin(), supported_types.end(), primitive_type) == supported_types.end()) { return Unimplemented("unsupported operand type %s in op %s", - PrimitiveType_Name(primitive_type).c_str(), - HloOpcodeString(instruction.opcode()).c_str()); + PrimitiveType_Name(primitive_type), + HloOpcodeString(instruction.opcode())); } return Status::OK(); } @@ -2842,9 +2975,10 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { } llvm::Value* IrEmitter::EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice parameters, - tensorflow::StringPiece name) { + const HloComputation& callee, absl::Span parameters, + absl::string_view name) { + CHECK(absl::c_binary_search(thread_local_computations_, &callee)); + const Shape& return_shape = callee.root_instruction()->shape(); // Lifting this restriction to allow "small" arrays should be easy. Allowing @@ -2859,38 +2993,39 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( CHECK(!parameter->getType()->isPointerTy()); llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( parameter->getType(), "arg_addr", &b_); - b_.CreateStore(parameter, parameter_addr); + Store(parameter, parameter_addr); parameter_addrs.push_back(parameter_addr); } llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(return_type, module_), - tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + absl::StrCat(name, "_retval_addr"), &b_, MinimumAlignmentForPrimitiveType(return_type)); - b_.CreateCall( - FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - parameter_addrs, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), - /*profile_counters_arg=*/GetProfileCountersArgument())); + Call(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*buffer_table_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); - return b_.CreateLoad(return_value_buffer); + return Load(return_value_buffer); } void IrEmitter::EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name) { - b_.CreateCall(FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - /*parameter_addresses=*/{}, &b_, name, - /*return_value_buffer=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()), - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); + absl::string_view name) { + CHECK(absl::c_binary_search(global_computations_, &callee)); + + Call(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + /*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( @@ -2902,7 +3037,7 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( const BufferAllocation::Slice root_buffer = assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); - return EmitTempBufferPointer(root_buffer, root_inst->shape()); + return EmitBufferPointer(root_buffer, root_inst->shape()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index c9a1dab62dcbcd926baa82737d24efa03fd326e9..daafef4eb38f14679e025d8e75dd671e94198102 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -39,13 +41,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -55,13 +56,14 @@ namespace cpu { // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. -class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin { public: // Create a new LLVM IR emitter. // // hlo_module: the HLO module we are emitting IR for. - // assignment: a BufferAssignment from which we know which temporary buffers - // are used by the HLO nodes. + // assignment: a BufferAssignment from which we know which buffers are used by + // the HLO nodes. // llvm_module: the LLVM module to emit IR into. // instruction_to_profile_idx: the mapping from HLO instructions to their // index in the profiling array. @@ -96,18 +98,21 @@ class IrEmitter : public DfsHloVisitorWithDefault { StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order); + const std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } + // builder() is for IrBuilderMixin. + llvm::IRBuilder<>* builder() { return &b_; } + // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); // Emit code to map one element according to `map_instr`. llvm::Value* EmitElementalMap( const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice elemental_operands, - tensorflow::StringPiece name); + absl::Span elemental_operands, + absl::string_view name); protected: // @@ -152,13 +157,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleConditional(HloInstruction* conditional) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; - Status HandleIota(HloInstruction* iota) override; Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; Status Postprocess(HloInstruction* hlo) override; + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { + return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie(); + } + private: // Private helper to initialize an IR function for the computation. void InitializeIrFunction(const string& function_name); @@ -215,31 +225,28 @@ class IrEmitter : public DfsHloVisitorWithDefault { // argument of the computation function being emitted by this emitter. llvm::Value* GetExecutableRunOptionsArgument(); - // Get the llvm::Value* that represents the "temps" argument of the + // Get the llvm::Value* that represents the "buffer_table" argument of the // computation function being emitted by this emitter. - llvm::Value* GetTempBuffersArgument(); + llvm::Value* GetBufferTableArgument(); - // Helper for EmitTempBufferPointer. - llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); + // Helper for EmitBufferPointer. + llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); - // Helper for EmitTempBufferPointer. - llvm::Value* EmitThreadLocalTempBufferPointer( + // Helper for EmitBufferPointer. + llvm::Value* EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape); // Emits code that computes the address of the given buffer allocation slice. - // - // TODO(sanjoy): This should be renamed to reflect that it no longer provides - // access to just temporaries. - llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); + llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); // Emits a function into the current module. This can be used for // computations embedded inside other computations, such as the // function that a map operation applies. StatusOr EmitFunction( HloComputation* function, // The function to emit. - tensorflow::StringPiece + absl::string_view function_name_suffix); // Used for LLVM IR register names. // Emits a call to a thread local function (e.g. to the computation nested @@ -248,17 +255,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { // // `parameters` holds the *scalar values* that need to be passed to the // callee. The return value is the scalar returned by the callee. - llvm::Value* EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice parameters, - tensorflow::StringPiece name); + llvm::Value* EmitThreadLocalCall(const HloComputation& callee, + absl::Span parameters, + absl::string_view name); // Emits a call to a "global" function (e.g. to the computation nested within // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to // the parameters and return values for these computations so there is no need // to explicitly pass parameters or return results. - void EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name); + void EmitGlobalCall(const HloComputation& callee, absl::string_view name); // Returns the buffer to which a global call to `callee` would have written // its result. @@ -268,8 +273,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { // match and are of one of the given supported types. Status ElementTypesSameAndSupported( const HloInstruction& instruction, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice supported_types); + absl::Span operands, + absl::Span supported_types); // Emit IR to perform a computation for every element in the given target op. // This produces a series of nested loops (one for each dimension of the op's @@ -285,7 +290,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator); Status EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator); // Emits a memcpy from the source instruction's result value to the @@ -316,10 +321,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { // concepts that generalize over other vectorizable operations. We should // consider pulling out these abstractions into a VectorizingIrEmitter or // something similar. - StatusOr EmitVectorizedReduce( - HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, HloComputation* function, - string* failure_reason); + StatusOr EmitVectorizedReduce(HloInstruction* reduce, + HloInstruction* arg, + HloInstruction* init_value, + absl::Span dimensions, + HloComputation* function, + string* failure_reason); // We'd like to keep one or two one cache-line's worth of data in registers // without generating IR with illegal (e.g. excessively large or @@ -369,16 +376,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, tensorflow::gtl::ArraySlice dimensions, + HloInstruction* arg, absl::Span dimensions, unsigned element_alignment); // Tries to emit a fast concatenate operation using memcpy. Returns true if // successful, and false on failure. On failure, sets "failure_reason" to a // string describing why it could not emit a fast concatenate. - StatusOr EmitFastConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands, - string* failure_reason); + StatusOr EmitFastConcatenate(HloInstruction* concatenate, + absl::Span operands, + string* failure_reason); // Emits LLVM IR to transfer "element_count" elements of type "primitive_type" // from the address "source" to the address "target". @@ -387,8 +393,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); - // Assignment of the temporary buffers needed by the computation and their - // shape information. + // Assignment of the buffers needed by the computation and their shape + // information. const BufferAssignment& assignment_; // The LLVM module into which IR will be emitted. @@ -568,6 +574,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::FlatMap constant_buffer_to_global_; + std::vector thread_local_computations_; + std::vector global_computations_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 2db4d000f5b149969c88fb4325ca28aa11dc3708..adfb8392bf6fa356f0a5cdab3ff74036eca8918e 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -77,19 +78,20 @@ void IrFunction::Initialize(const string& function_name, const bool optimize_for_size_requested, const bool enable_fast_math) { // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // void function(i8* retval, i8* run_options, i8** params, i8** + // buffer_table, // i64* dynamic_loop_bounds, i64* prof_counters) // // For thread local functions: // retval: points to the returned value. // params: address of an array with pointers to parameters. - // temps: is null + // buffer_table: is null // // For global functions: // retval: is null // params: is null - // temps: address of an array with pointers to temporary buffers and entry - // computation parameters. + // buffer_table: address of an array with pointers to temporary buffers and + // entry computation parameters (but not to constant buffers). // // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -115,7 +117,7 @@ void IrFunction::Initialize(const string& function_name, // \---------/ \---------/ \-----------/ // // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // buffer_table---> | buff 0 | guff 1 | ..... | buff N-1 | // | addr | addr | | addr | // \---------------------------------------------/ // | | | @@ -133,9 +135,9 @@ void IrFunction::Initialize(const string& function_name, // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | // \---------------------------------------------/ - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. + // Even though the type of params and buffer_table is void** in the host's + // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to + // the code to use GEPs to unravel the indirection layers. llvm::FunctionType* function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), /*Params=*/ @@ -159,8 +161,8 @@ void IrFunction::Initialize(const string& function_name, exec_run_options_arg_ = &*arg_iter; (++arg_iter)->setName("params"); parameters_arg_ = &*arg_iter; - (++arg_iter)->setName("temps"); - temp_buffers_arg_ = &*arg_iter; + (++arg_iter)->setName("buffer_table"); + buffer_table_arg_ = &*arg_iter; if (num_dynamic_loop_bounds_ > 0) { (++arg_iter)->setName("dynamic_loop_bounds"); dynamic_loop_bounds_arg_ = &*arg_iter; @@ -189,7 +191,7 @@ void IrFunction::Initialize(const string& function_name, llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_GT(num_dynamic_loop_bounds_, 0); CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + string name = absl::StrCat("dynamic_loop_bound_", offset); return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), b_->getInt64(offset), AsStringRef(name))); } @@ -199,10 +201,10 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // Returns an array of compute function call arguments (including parameter // address buffer). std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, - llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, - llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + absl::Span parameter_addresses, llvm::IRBuilder<>* b, + absl::string_view name, llvm::Value* return_value_buffer, + llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg, + llvm::Value* profile_counters_arg) { llvm::Value* parameter_addresses_buffer; if (parameter_addresses.empty()) { @@ -211,13 +213,13 @@ std::vector GetArrayFunctionCallArguments( } else { parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + absl::StrCat(name, "_parameter_addresses"), b); for (size_t i = 0; i < parameter_addresses.size(); ++i) { llvm::Value* parameter_as_i8ptr = b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat( - name, "_parameter_", i, "_address_as_i8ptr"))); + AsStringRef(absl::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); llvm::Value* slot_in_param_addresses = b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); @@ -229,7 +231,7 @@ std::vector GetArrayFunctionCallArguments( }; std::vector arguments{ to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), - parameter_addresses_buffer, temp_buffers_arg}; + parameter_addresses_buffer, buffer_table_arg}; if (profile_counters_arg != nullptr) { arguments.push_back(profile_counters_arg); } @@ -320,8 +322,7 @@ Status EmitCallToParallelForkJoin( /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/partitions_array, /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions"))); // Add argument specifying parallel dimension partitions. fork_join_arguments.push_back( diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index a41cbb64cdd9f5b6de5d1eadfbf7e63e1e984801..623a5f185fa1fd0526bc8664e2ba11c9dde79b1d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#include "absl/types/span.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -80,8 +80,9 @@ class IrFunction { // Get the llvm::Value* that represents this functions parameters argument. llvm::Value* parameters_arg() { return parameters_arg_; } - // Get the llvm::Value* that represents this functions "temps" argument. - llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + // Get the llvm::Value* that represents this functions "buffer_table" + // argument. + llvm::Value* buffer_table_arg() { return buffer_table_arg_; } // Get the llvm::Value* that represents this functions "prof_counters" // argument. @@ -108,17 +109,17 @@ class IrFunction { llvm::Argument* result_arg_; llvm::Value* exec_run_options_arg_; llvm::Value* parameters_arg_; - llvm::Value* temp_buffers_arg_; + llvm::Value* buffer_table_arg_; llvm::Value* dynamic_loop_bounds_arg_ = nullptr; llvm::Value* profile_counters_arg_; }; // Returns an array of compute function call argument ir values. std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, - llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, - llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + absl::Span parameter_addresses, llvm::IRBuilder<>* b, + absl::string_view name, llvm::Value* return_value_buffer, + llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg, + llvm::Value* profile_counters_arg); // Emits a call to a runtime fork/join function which dispatches parallel // calls to 'parallel_function' (and joins threads before returning). diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 8560e4296aa95fe791446abb1b4363b9145f343e..f8441c3e345504616485c6b34b4302acd5cc23a3 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace cpu { @@ -30,8 +30,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( dynamic_loop_bounds_(dynamic_loop_bounds) {} std::vector -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { CHECK_NE(index_type, nullptr); CHECK(!ShapeUtil::IsTuple(shape_)); @@ -52,15 +52,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; std::unique_ptr loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); + /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, + end_index); array_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 076c683ca566f2c53992c358903d2aadead290f9..a604e1db222139c239a2a89359a7359463e0def7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4fa5984b0466b178a587e97cbced97deac749f74..ede7f433ca6b2cc5629115f800348be9dfb2b93b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" @@ -109,7 +111,7 @@ ParallelTaskAssignment::ParallelTaskAssignment( : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. - auto cost_analysis = MakeUnique(shape_size); + auto cost_analysis = absl::make_unique(shape_size); HloComputation* computation = module->entry_computation(); Status status = computation->root_instruction()->Accept(cost_analysis.get()); if (status.ok()) { @@ -140,6 +142,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || + opcode == HloOpcode::kSort || (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || @@ -216,8 +219,7 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( // Outline 'instruction' in 'computation' for parallel task assignment. auto* call = module->OutlineExpressionFromComputation( - {instruction}, - tensorflow::strings::StrCat("parallel_", instruction->name()), + {instruction}, absl::StrCat("parallel_", instruction->name()), computation); // Set assigned dimension partitioning to 'instruction'. diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 8becc8fa23424d7454cc783eb9d853aecb5d053b..3822d5300e30704f68b2cf0c7f0b77d595c17a25 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -60,7 +60,7 @@ class ParallelTaskAssignment { // own embedded computation, which is compiled as a parallel compute function, // and which is invoked from a kCall instruction that is lowered in codegen to // a runtime parallel fork/join call. -class ParallelTaskAssigner : public HloPassInterface { +class ParallelTaskAssigner : public HloModulePass { public: // 'max_parallelism': the maximum parallel task count per instruction. // 'shape_size': shape size function used by HloCostAnalysis during parallel @@ -73,7 +73,7 @@ class ParallelTaskAssigner : public HloPassInterface { target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cpu-parallel-task-assigner"; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 36c9f743859ae2da6c4fb3fd753bd7862fe2d3ab..fad76338a57cd9eb21d9469ca8552efa8ea0129b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -36,7 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} @@ -110,9 +109,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - infeed0 = (u32[12345678,2]{1,0}, token[]) infeed() + token = token[] after-all() + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token) infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 - ROOT outfeed0 = token[] outfeed(infeed0.data) + ROOT outfeed0 = token[] outfeed(infeed0.data, token) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index a5f34908d70dd18ec017bdf9833c7df40f80db07..2d9492eacfea34bec3b0f1115e171a5328b7cdc3 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -61,7 +61,7 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, // TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, - void** temps, uint64* prof_counters, int32 num_partitions, + void** buffer_table, uint64* prof_counters, int32 num_partitions, int64* partitions, int32 num_partitioned_dims, void* function_ptr) { VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions @@ -81,9 +81,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; ++i) { const int64 offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, buffer_table, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, nullptr, temps, + function(result_ptr, run_options_ptr, nullptr, buffer_table, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; @@ -91,7 +91,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( } // Call first compute function inline. - function(result_ptr, run_options_ptr, params, temps, &partitions[0], + function(result_ptr, run_options_ptr, params, buffer_table, &partitions[0], prof_counters); VLOG(3) << "ParallelForkJoin partition 0 done."; bc.Wait(); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h index 1cf0ec6e3df400e35fa4e755a0b25b4ce7966e8f..a279c7d2d61bdd138f5285a8c8ccc89d22db9692 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -24,7 +24,7 @@ extern "C" { // threads before returning. See comments in runtime_fork_join.cc for details. extern void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, - void** temps, tensorflow::uint64* prof_counters, + void** buffer_table, tensorflow::uint64* prof_counters, tensorflow::int32 num_partitions, tensorflow::int64* partitions, tensorflow::int32 num_partitioned_dims, void* function_ptr); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc new file mode 100644 index 0000000000000000000000000000000000000000..e0e7deb98e579c090c8fae320a3ba8a3ce8dbe5f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -0,0 +1,236 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" + +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace { +using tensorflow::int16; +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::int8; +using tensorflow::uint16; +using tensorflow::uint32; +using tensorflow::uint64; +using tensorflow::uint8; + +template +void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements); +} + +// For floating point numbers, we want a total order comparator. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. Also we want to have a stable sort, so if the keys are the +// same, we compare the index values. +template +bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) { + bool lhs_is_negative = std::signbit(lhs); + bool rhs_is_negative = std::signbit(rhs); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(lhs); + bool rhs_nan = std::isnan(rhs); + // Exactly one number is nan? + if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; + } + if (lhs != rhs) { + return lhs < rhs; + } + return lhs_index < rhs_index; +} + +template <> +void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair* row_to_sort, + int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan( + Eigen::half_impl::half_to_float(lhs.first), lhs.second, + Eigen::half_impl::half_to_float(rhs.first), rhs.second); + }); +} + +template +void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + // High-level idea of the iteration/sorting logic: + // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the + // dimension to sort, c is the product of the more minor dimensions (set to 1 + // if b is the most minor dimension), and a is the product of the more major + // dimensions (set to 1 if b is the most major dimension). There are a * c + // many rows that we need to sort. We iterate through these, calculate a + // 'base_offset' value which points to the first element in that row, and add + // i * c for accessing the 'i'-th element in that row. + + int64 sort_dimension_elements = b; + int64 num_iteration_elements = a * c; + int64 sort_dimension_offset = c; + + std::unique_ptr[]> row_to_sort( + new std::pair[sort_dimension_elements]); + std::unique_ptr reordered_values( + new std::string[sort_dimension_elements]); + for (int64 index = 0; index < num_iteration_elements; ++index) { + // 'index' can be split into two values which index into the 'c' dimension + // and the 'a' dimension, respectively. 'index' % 'c' is the index into the + // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When + // calculating the base offset, we need to multiply the index into the 'a' + // dimension with 'b' * 'c'. + // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'. + int64 base_offset = + index % sort_dimension_offset + + (index - index % sort_dimension_offset) * sort_dimension_elements; + // TODO(b/26783907): We could define a custom iterator class that references + // both arrays. Then we could avoid the intermediate copy. However this + // would become more complicated, and it is not clear if the benefit is high + // enough. + for (int64 i = 0; i < sort_dimension_elements; ++i) { + row_to_sort[i] = + std::make_pair(keys[base_offset + i * sort_dimension_offset], i); + } + KeyValueSort(row_to_sort.get(), sort_dimension_elements); + for (int64 i = 0; i < sort_dimension_elements; ++i) { + keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; + } + if (values == nullptr) { + continue; + } + + // Reorder the values according to the order defined by the keys. + for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = + (base_offset + row_to_sort[i].second * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + + reordered_values[i] = std::string(values + memory_index, + values_primitive_type_size_in_bytes); + } + for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = (base_offset + i * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + memcpy(values + memory_index, reordered_values[i].c_str(), + values_primitive_type_size_in_bytes); + } + } +} +} // namespace + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( + bool* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( + int8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( + uint8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( + int16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( + uint16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( + Eigen::half* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( + int32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( + uint32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( + float* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( + int64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( + uint64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( + double* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h new file mode 100644 index 0000000000000000000000000000000000000000..28e35e82c18cbf078f8a1e7f5b818bf839d3d3df --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b' +// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr. +// If 'values' is not nullptr, the elements in 'values' are reordered in such a +// way that if the element at index 'i' in 'keys' was moved to index 'j', the +// element at index 'i' in 'values' is also moved to index 'j' (which means that +// the same elements correspond to each other as before). +extern void __xla_cpu_runtime_KeyValueSortPRED( + bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS8( + tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU8( + tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS16( + tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU16( + tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF16( + Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS32( + tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU32( + tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF32( + float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS64( + tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU64( + tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF64( + double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index 997fdd2ab309f0b68a9dbd0f156a8dc19955b437..8dc5f3c93b6ba1a722ea7b23b4b5190ac0600cd6 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) +#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "third_party/intel_mkl_ml/include/mkl_cblas.h" #include "third_party/intel_mkl_ml/include/mkl_service.h" diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index f227e4ae139b92e56786e38ef8eef72c9e2cd424..55d5925642a97b1a0425c092c82070d4b8e59df3 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -37,21 +37,20 @@ int main(int argc, char** argv) { xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie()); // Transfer parameters. - std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); // Build computation. xla::XlaBuilder builder(""); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p1, p0, {0}); xla::StatusOr computation_status = builder.Build(); @@ -59,17 +58,16 @@ int main(int argc, char** argv) { // Execute and transfer result of computation. xla::ExecutionProfile profile; - xla::StatusOr> result = - client->ExecuteAndTransfer( - computation, - /*arguments=*/{param0_data.get(), param1_data.get()}, - /*execution_options=*/nullptr, - /*execution_profile=*/&profile); - std::unique_ptr actual = result.ConsumeValueOrDie(); + xla::StatusOr result = client->ExecuteAndTransfer( + computation, + /*arguments=*/{param0_data.get(), param1_data.get()}, + /*execution_options=*/nullptr, + /*execution_profile=*/&profile); + xla::Literal actual = result.ConsumeValueOrDie(); - LOG(INFO) << tensorflow::strings::Printf("computation took %lldns", - profile.compute_time_ns()); - LOG(INFO) << actual->ToString(); + LOG(INFO) << absl::StrFormat("computation took %dns", + profile.compute_time_ns()); + LOG(INFO) << actual.ToString(); return 0; } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index ae80a6f4977f85cfd9f872734fd0a69432a1f382..1a3d82de954318368d61e3feeb0345dc592dcd8b 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloTestBase { +class ShapePartitionAssignerTest : public HloVerifiedTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloTestBase { +class ShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; }; @@ -102,22 +102,22 @@ TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) { { ShapePartitionIterator iterator(shape, {1}); EXPECT_EQ(1, iterator.GetTotalPartitionCount()); - EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0))); + EXPECT_TRUE(absl::c_equal(Partition({{0, 5}}), iterator.GetPartition(0))); } { ShapePartitionIterator iterator(shape, {2}); EXPECT_EQ(2, iterator.GetTotalPartitionCount()); - EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0))); - EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1))); + EXPECT_TRUE(absl::c_equal(Partition({{0, 2}}), iterator.GetPartition(0))); + EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(1))); } { ShapePartitionIterator iterator(shape, {3}); EXPECT_EQ(3, iterator.GetTotalPartitionCount()); - EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0))); - EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1))); - EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2))); + EXPECT_TRUE(absl::c_equal(Partition({{0, 1}}), iterator.GetPartition(0))); + EXPECT_TRUE(absl::c_equal(Partition({{1, 1}}), iterator.GetPartition(1))); + EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(2))); } } @@ -128,24 +128,24 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { ShapePartitionIterator iterator(shape, {1, 1}); EXPECT_EQ(1, iterator.GetTotalPartitionCount()); EXPECT_TRUE( - ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); + absl::c_equal(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); } { ShapePartitionIterator iterator(shape, {2, 2}); EXPECT_EQ(4, iterator.GetTotalPartitionCount()); EXPECT_TRUE( - ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); + absl::c_equal(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); EXPECT_TRUE( - ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); + absl::c_equal(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); EXPECT_TRUE( - ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); + absl::c_equal(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); EXPECT_TRUE( - ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); + absl::c_equal(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); } } -class RandomShapePartitionIteratorTest : public HloTestBase { +class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index be772cfb7e564cebc5725854dbf5678e5c507556..9ec0c8f65705db335379649def746921e6b05bea 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -20,13 +20,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Mangler.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Host.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" @@ -170,15 +171,14 @@ namespace { bool RegisterKnownJITSymbols() { CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); -#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ - do { \ - auto* function_address = \ - reinterpret_cast(__xla_cpu_runtime_##base_name); \ - registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ - function_address); \ - CHECK_EQ( \ - tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \ - "__xla_cpu_runtime_" #base_name); \ +#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ + do { \ + auto* function_address = \ + reinterpret_cast(__xla_cpu_runtime_##base_name); \ + registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ + function_address); \ + CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ + "__xla_cpu_runtime_" #base_name); \ } while (false) REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); @@ -203,6 +203,18 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64); registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 181cec3cdddeb40daf5276d9d1d6a139417a6072..4b129c95d46d8b5a119e5d23eef387daf7863cce 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,9 +48,11 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -94,6 +96,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", "@llvm//:core", ], ) @@ -108,6 +111,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -118,9 +122,11 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -174,3 +180,17 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cpu_key_value_sort_test", + srcs = ["cpu_key_value_sort_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 6fcce42eaa4599eb8a6dacc1bd39eefd39aa5e50..18ee25ba9158c28baaf01492c290638b9673f1ec 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -19,10 +19,11 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) { HloInstruction* rhs = builder.AddInstruction( HloInstruction::CreateParameter(1, param_shape, "input")); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } @@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) { HloInstruction* lhs_transposed = builder.AddInstruction( HloInstruction::CreateTranspose(param_shape, lhs, {1, 0})); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index d98856fdbf4165a5909f193ebe8512e21af83dfc..1deb412064b02988a8d4a6d726969c948d354d47 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -17,15 +17,15 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloTestBase { +class CpuFusionTest : public HloVerifiedTestBase { protected: CpuFusionTest() {} @@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto builder = HloComputation::Builder(TestName()); auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); - Shape vshape = input_literal1->shape(); + Shape vshape = input_literal1.shape(); auto input1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal1))); @@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -75,16 +75,16 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { EXPECT_EQ(4, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. - LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, *result, error_spec_); + LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, result, error_spec_); } TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -122,20 +122,19 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. - LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, *result, - error_spec_); + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, result, error_spec_); } -TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { - // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the +TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { + // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the // middle. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -184,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -209,11 +208,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { << fusion_instruction2->fused_instructions_computation()->ToString(); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, - *result, error_spec_); + result, error_spec_); } TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { @@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // each fusion instruction to ensure that negate is not duplicated. auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -256,7 +255,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // Run fusion. CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -315,7 +314,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index c35569c6619ba5b534c5d8bb7ad683d84b6ecf4b..5cc6d01c0f15d4209cbc1fb259a0078fb9957f6e 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { // Send 5 Infeed data of shape F32[3]. ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({1, 2, 3}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({1, 2, 3}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({4, 5, 6}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({4, 5, 6}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({7, 8, 9}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({10, 11, 12}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({10, 11, 12}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({13, 14, 15}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 3 infeed data should be added. - LiteralTestUtil::ExpectR0Near(45.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(45.0f, result_literal, ErrorSpec{1e-7}); } // Tests two Infeed operations with a total order. The order is enforced by @@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({3, 4}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({3, 4}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({5, 6}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8}), + LiteralUtil::CreateR0(false)}))); // Asynchronously launch the execution on the device. std::unique_ptr result; @@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8, 9}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8, 9}), + LiteralUtil::CreateR0(false)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({4, 5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({4, 5, 6}), + LiteralUtil::CreateR0(true)}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 6 infeed data should be added. - LiteralTestUtil::ExpectR0Near(66.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(66.0f, result_literal, ErrorSpec{1e-7}); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 973aac8766f5aabca15e5173b43480c113c100dd..a434c04a980b9b3cd849792b97a0d9e965ba09f2 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -32,9 +32,9 @@ const char* const kTriple_android_arm = "armv7-none-android"; struct IntrinsicTestSpec { HloOpcode opcode; - tensorflow::StringPiece triple; - tensorflow::StringPiece features; - tensorflow::StringPiece check_lines; + absl::string_view triple; + absl::string_view features; + absl::string_view check_lines; }; // Tests that unary functions get lowered using intrinsic calls. @@ -65,9 +65,8 @@ class CpuUnaryIntrinsicTest features = ""; } - return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(), - features.empty() ? "" : "_With", - features.c_str()); + return absl::StrCat(opcode, "_On_", triple, + (features.empty() ? "" : "_With"), features); } }; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3934c03a04c978009282b3cd0d39bacf9b12a356 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuKeyValueSortTest : public CpuCodegenTest {}; + +TEST_F(CpuKeyValueSortTest, SortR1) { + const string hlo_text = R"( +HloModule KeyValueSort + +ENTRY main { + a = f32[10] parameter(0) + + ROOT result = f32[10] sort(f32[10] a), dimensions={0} +} +)"; + + string filecheck_pattern = R"( +CHECK: call void @__xla_cpu_runtime_KeyValueSort +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 90b99c828e2fcfd77579026a39d3a6711599feee..3b87683ffffefd2aa24dd234cc072425bef00a24 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -38,7 +38,8 @@ while_body { while_cond { arg_cond = f32[2,3,2] parameter(0) - infeed = (pred[], token[]) infeed() + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } @@ -50,8 +51,9 @@ ENTRY main { {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - out0 = token[] outfeed(f32[2,3,2] const_a) - ROOT out1 = token[] outfeed(f32[2,3,2] const_b) + token = token[] after-all() + out0 = token[] outfeed(f32[2,3,2] const_a, token[] token) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token) } )"; @@ -85,7 +87,8 @@ while_body { while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - infeed = (pred[], token[]) infeed() + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } @@ -94,8 +97,9 @@ ENTRY main { const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) + token = token[] after-all() + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 01daed4bcd38323bfe33e798a78c2b00b150a1bc..7af51db55af44ae1e437ea8e4de7427012cad82f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {}; TEST_F(CpuNoAliasTest, Concat) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* param_x = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); @@ -62,7 +61,8 @@ TEST_F(CpuNoAliasTest, Concat) { // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. auto status_or_buffer_assn = BufferAssigner::Run( - hlo_module.get(), MakeUnique(hlo_module.get()), + hlo_module.get(), + absl::make_unique(hlo_module.get()), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return /*alignment=*/1; }); ASSERT_EQ(status_or_buffer_assn.status(), Status::OK()); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index dac416e1c78c2f60d458480c5062f48b77d4878d..e2c7af541eede5265f274c72f55305549f059839 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -32,7 +32,8 @@ ENTRY main { {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - outfeed = token[] outfeed(f32[2,3,2] const_a) + token = token[] after-all() + outfeed = token[] outfeed(f32[2,3,2] const_a, token) ROOT root = () tuple() } )"; @@ -53,6 +54,33 @@ CHECK: private constant [48 x i8] /*match_optimized_ir=*/false); } +TEST_F(CpuOutfeedTest, OutfeedTokenInTuple) { + const string hlo_text = R"( +HloModule OutfeedTokenInTuple + +ENTRY main { + const = f32[] constant(42) + epoch = token[] after-all() + outfeed.tok = token[] outfeed(const, epoch) + ROOT root = (token[], f32[]) tuple(outfeed.tok, const) +} +)"; + + string filecheck_pattern = R"( +CHECK: Outfeed +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 3274be8d9dbfaa55e250748a389ad34fdeb81922..1bd4b59dd604687589eee061d34aa9ca94f6d700 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "absl/algorithm/container.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -422,12 +423,12 @@ TileVariable::TileVariable(VectorSupportLibrary* vector_support, std::vector TileVariable::Get() const { std::vector result; - c_transform(storage_, std::back_inserter(result), - [&](VectorVariable vect_var) { return vect_var.Get(); }); + absl::c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); return result; } -void TileVariable::Set(tensorflow::gtl::ArraySlice value) { +void TileVariable::Set(absl::Span value) { CHECK_EQ(value.size(), storage_.size()); for (int64 i = 0, e = value.size(); i < e; i++) { storage_[i].Set(value[i]); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index c728f6df0aef83e6ddc6c932a347f14da06d9d0d..5690d2be2fe3e21c96b51a5226e0b29148217fd1 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -18,12 +18,12 @@ limitations under the License. #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -324,7 +324,7 @@ class TileVariable { std::vector initial_value); std::vector Get() const; - void Set(tensorflow::gtl::ArraySlice value); + void Set(absl::Span value); private: std::vector storage_; diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc index 47543b2082f55cf7b8cf60f1c5bbb16a0a609912..b9e47f5aade3334bece28643e6e32ecfce3bf67b 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc @@ -37,7 +37,7 @@ void XfeedQueueManager::Reset() { } void XfeedQueueManager::EnqueueBuffersAtomically( - tensorflow::gtl::ArraySlice buffers) { + absl::Span buffers) { tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffers_.empty(); for (XfeedBuffer* b : buffers) { diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index b4ace232607e14fbfec01d48946f0031d96cd027..990ff94ba2338cb663b655ca3106bda83ab718a3 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -22,10 +22,10 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -63,8 +63,7 @@ class XfeedQueueManager { // called when the buffer will no longer be accessed by the XfeedManager, // either as a result of a call to Reset or because the runtime has dequeued // and used the buffer. - void EnqueueBuffersAtomically( - tensorflow::gtl::ArraySlice buffers); + void EnqueueBuffersAtomically(absl::Span buffers); // Blocks until the queue is non-empty, then returns the buffer at the head of // the queue. Sets the current buffer to be the returned buffer. It is an diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc index 8fe65f488a2f0c4031926fa4c5f02dcf5473568d..cc38b81455b5a35cdcd07ac1dfb80cc7b101a7bc 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc @@ -66,9 +66,9 @@ void ProcessNextBuffer(int32 length) { auto shape = ShapeUtil::MakeShape(U8, {length}); string bytes = shape.SerializeAsString(); void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - length, bytes.data(), bytes.size()); - __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer, - bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, bytes.data(), bytes.size()); + __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size()); } // Performs the acquire/release sequence on the outfeed, as the generated CPU @@ -76,16 +76,16 @@ void ProcessNextBuffer(int32 length) { void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) { string bytes = shape.SerializeAsString(); void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - length, bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, bytes.data(), bytes.size()); __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - length, buffer, bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size()); } TEST_F(InfeedManagerTest, SingleThreadedSequential) { TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->infeed()->EnqueueBuffersAtomically({a}); xfeed->infeed()->EnqueueBuffersAtomically({b}); @@ -97,7 +97,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->infeed()->EnqueueBuffersAtomically({a}); ProcessNextBuffer(a->length()); @@ -108,7 +108,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { TEST_F(InfeedManagerTest, MultiThreaded) { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); const int32 length = 64; @@ -130,7 +130,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) { TEST_F(InfeedManagerTest, OutfeedWrongShape) { TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->outfeed()->EnqueueBuffersAtomically({b}); ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33})); diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h index 56b28fd22da1ea6bc19f98e76f0f2ef4044cd3af..aaa41fc4fe779cdf01a34e86855cac02552ee383 100644 --- a/tensorflow/compiler/xla/service/defuser.h +++ b/tensorflow/compiler/xla/service/defuser.h @@ -25,11 +25,11 @@ namespace xla { // A pass which replaces all fusion instructions with the equivalent un-fused // instructions. -class Defuser : public HloPassInterface { +class Defuser : public HloModulePass { public: Defuser() {} ~Defuser() override {} - tensorflow::StringPiece name() const override { return "defuser"; } + absl::string_view name() const override { return "defuser"; } // Run defusion on the given module. Returns whether the module was // changed. diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index d938f3a2c4b5bfdd70d5a614b9890b4d7bf050f7..b3549acfc291a54b2345b006310613c3a45a4b47 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -21,8 +21,31 @@ limitations under the License. namespace xla { +namespace { + +// Pass which strips control dependencies from all instructions in the module. +class ControlDepRemover : public HloModulePass { + public: + ControlDepRemover() = default; + absl::string_view name() const override { return "control-dep-remover"; } + + StatusOr Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + changed = changed || !instruction->control_predecessors().empty(); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + } + } + return changed; + } +}; + +} // namespace + Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index cc1695b7f863805e0b483478639c17cb9061310a..46dcc3a438cbdf3ff1b3c99fa15b35ee7a4e280e 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -30,10 +30,10 @@ namespace xla { // // Current despecialization passes are Defuser, ImplicitBroadcastRemover, // and BFloat16MixedPrecisionRemoval. -class Despecializer : public HloPassInterface { +class Despecializer : public HloModulePass { public: Despecializer(); - tensorflow::StringPiece name() const override { return "despecializer"; } + absl::string_view name() const override { return "despecializer"; } StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index e228bb56bce8febcca28ae171f6de90973d020ab..edbcb25247421cdb50a845df1ec8b1851970efe3 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -25,7 +25,7 @@ namespace xla { StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( const se::Platform* platform, - tensorflow::gtl::ArraySlice stream_executors) + absl::Span stream_executors) : DeviceMemoryAllocator(platform), stream_executors_(stream_executors.begin(), stream_executors.end()) {} @@ -36,9 +36,8 @@ StatusOr StreamExecutorMemoryAllocator::Allocate( se::DeviceMemoryBase result = stream_executor->AllocateArray(size); if (size > 0 && result == nullptr) { return ResourceExhausted( - "Failed to allocate request for %s (%lluB) on device ordinal %d", - tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, - device_ordinal); + "Failed to allocate request for %s (%uB) on device ordinal %d", + tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal); } return OwningDeviceMemory(result, device_ordinal, this); } @@ -61,12 +60,12 @@ StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( } if (device_ordinal >= stream_executors_.size()) { return InvalidArgument( - "device ordinal value (%d) >= number of devices (%zu)", device_ordinal, + "device ordinal value (%d) >= number of devices (%u)", device_ordinal, stream_executors_.size()); } if (stream_executors_[device_ordinal] == nullptr) { return NotFound("Device %s:%d present but not supported", - platform()->Name().c_str(), device_ordinal); + platform()->Name(), device_ordinal); } return stream_executors_[device_ordinal]; } diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index d87b86caf0d3acaa5bf9a455cff2315cedb2496d..a2308ee7a4137bbafe9804c30e33cc68d4628588 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -80,7 +80,7 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { public: StreamExecutorMemoryAllocator( const se::Platform* platform, - tensorflow::gtl::ArraySlice stream_executors); + absl::Span stream_executors); StatusOr Allocate(int device_ordinal, uint64 size, bool retry_on_failure) override; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 2172ae0a29626660e8abd29a789e0baa3831519d..3e7373adc5ab8a60fd18348ce2477175aaaa8fd4 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -28,14 +28,14 @@ template Status DfsHloVisitorBase::HandleElementwiseUnary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template Status DfsHloVisitorBase::HandleElementwiseBinary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 9f867491258727e0bb53d960af3b977690f8f31a..5761573791d90e45c65b55124a4bae3c5b929ef1 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,14 +19,14 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -106,6 +106,8 @@ class DfsHloVisitorBase { virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -207,7 +209,6 @@ class DfsHloVisitorBase { virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; - virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; virtual Status HandleRng(HloInstructionPtr hlo) = 0; virtual Status HandleReverse(HloInstructionPtr hlo) = 0; virtual Status HandleSort(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index ae8a066d626fcf6c6670f4994a58f0b8e8027aad..4cd10ab06cd3b804406607212d3f3c316d6cff95 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,14 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -94,6 +94,12 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } + Status HandleAllToAll(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleCollectivePermute(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } @@ -103,9 +109,6 @@ class DfsHloVisitorWithDefaultBase Status HandleOutfeed(HloInstructionPtr outfeed) override { return DefaultAction(outfeed); } - Status HandleHostCompute(HloInstructionPtr host_compute) override { - return DefaultAction(host_compute); - } Status HandleReverse(HloInstructionPtr reverse) override { return DefaultAction(reverse); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 12faed69677cd99c6ed82c8d13dad3138d9461b7..b2ba2617902104bfea06713332fa1c2aedea536d 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -134,8 +134,9 @@ Status DecomposeBatchDot(HloInstruction* dot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( - dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); + auto dot_r2 = computation->AddInstruction( + HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2, + dot_dnums, dot->precision_config())); // Reshape Dot to R3 so we can concat along batch dimension. auto dot_r3 = computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index 1959b687f16d6909a3283021c8635b3e65e6e412..40e7a3b4c25ff20674de0cce3fe2907fc43a5cb9 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -23,13 +23,13 @@ namespace xla { // DotDecomposer is a pass which decomposes batch Dot operations into a // sequence of smaller (R2) Dot operations. -class DotDecomposer : public HloPassInterface { +class DotDecomposer : public HloModulePass { public: // Decomposes batch Dot operations when 'decompose_batch_dot' is true. DotDecomposer(bool decompose_batch_dot = true) : decompose_batch_dot_(decompose_batch_dot) {} ~DotDecomposer() = default; - tensorflow::StringPiece name() const override { return "dot_decomposer"; } + absl::string_view name() const override { return "dot_decomposer"; } // Run DotDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index f05c2d63d2da3f7458c758308a8fc02c3b77af9b..515267edd7caf42e04ebe638b99006db8967ea30 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -21,11 +21,15 @@ limitations under the License. #include // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -38,17 +42,16 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +using absl::StrCat; using llvm_ir::AsStringRef; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrCat; namespace { @@ -203,7 +206,7 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, } // namespace StatusOr ElementalIrEmitter::EmitUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { if (op->opcode() == HloOpcode::kCopy) { return operand_value; } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || @@ -217,7 +220,7 @@ StatusOr ElementalIrEmitter::EmitUnaryOp( } StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -229,14 +232,14 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (to_type == PRED) { return b_->CreateZExt( - b_->CreateICmpNE(operand_value, llvm::ConstantInt::get( - operand_value->getType(), 0)), + ICmpNE(operand_value, + llvm::ConstantInt::get(operand_value->getType(), 0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsIntegralType(to_type)) { - return b_->CreateIntCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), - primitive_util::IsSignedIntegralType(from_type)); + return IntCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_), + primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { @@ -252,19 +255,17 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { return EmitComposeComplex( - op, b_->CreateSIToFP(operand_value, to_ir_component_type), - nullptr); + op, SIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { return EmitComposeComplex( - op, b_->CreateUIToFP(operand_value, to_ir_component_type), - nullptr); + op, UIToFP(operand_value, to_ir_component_type), nullptr); } } return Unimplemented("conversion from primitive type %s to %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -275,14 +276,13 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths (%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -292,10 +292,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( if (is_signed) { auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = b_->CreateICmpSGE(operand_value, zero); - return b_->CreateSelect(cmp, operand_value, - b_->CreateNeg(operand_value)); + auto cmp = ICmpSGE(operand_value, GetZero(type)); + return Select(cmp, operand_value, Neg(operand_value)); } else { return operand_value; } @@ -307,44 +305,37 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( {operand_value->getType()}, b_); } case HloOpcode::kSign: { - bool is_signed = - primitive_util::IsSignedIntegralType(op->shape().element_type()); + CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type())) + << op->shape().element_type(); auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = b_->CreateICmpEQ(operand_value, zero); - if (is_signed) { - auto ashr = - b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1); - return b_->CreateSelect(cmp, zero, b_->CreateOr(ashr, 1)); - } else { - return b_->CreateSelect(cmp, zero, llvm::ConstantInt::get(type, 1)); - } + auto cmp = ICmpEQ(operand_value, GetZero(type)); + auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1); + return Select(cmp, GetZero(type), Or(ashr, 1)); } case HloOpcode::kNegate: - return b_->CreateNeg(operand_value); + return Neg(operand_value); case HloOpcode::kNot: { auto type = op->shape().element_type(); if (type == PRED) { // It is not sufficient to just call CreateNot() here because a PRED // is represented as an i8 and the truth value is stored only in the // bottom bit. - return b_->CreateZExt( - b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } else if (primitive_util::IsIntegralType(type)) { - return b_->CreateNot(operand_value); + return Not(operand_value); } return Unimplemented("unary op Not is not defined for type '%d'", type); } default: return Unimplemented("unary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -361,8 +352,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitComposeComplex( op, - b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType( - to_component_type, module_)), + FPCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } if (from_type == BF16) { @@ -378,26 +369,25 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (to_type == PRED) { return b_->CreateZExt( - b_->CreateFCmpUNE( - operand_value, - llvm::ConstantFP::get(operand_value->getType(), 0.0)), + FCmpUNE(operand_value, + llvm::ConstantFP::get(operand_value->getType(), 0.0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsFloatingPointType(to_type)) { - return b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsSignedIntegralType(to_type)) { - return b_->CreateFPToSI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToSI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(to_type)) { - return b_->CreateFPToUI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToUI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return Unimplemented("unhandled conversion operation: %s => %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -408,14 +398,13 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths (%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -431,6 +420,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: return EmitSin(op->shape().element_type(), operand_value); + case HloOpcode::kTanh: + return EmitTanh(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {operand_value}, @@ -451,11 +442,10 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(operand_value, zero); - auto olt = b_->CreateFCmpOLT(operand_value, zero); - return b_->CreateSelect( - oeq, zero, - b_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), + auto oeq = FCmpOEQ(operand_value, zero); + auto olt = FCmpOLT(operand_value, zero); + return Select(oeq, zero, + Select(olt, llvm::ConstantFP::get(type, -1.0), llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { @@ -465,24 +455,24 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, b_); auto infinity = llvm::ConstantFP::getInfinity(type); - auto not_infinite = b_->CreateFCmpONE(abs_value, infinity); + auto not_infinite = FCmpONE(abs_value, infinity); return b_->CreateZExt(not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: - return b_->CreateFNeg(operand_value); + return FNeg(operand_value); case HloOpcode::kReal: return operand_value; case HloOpcode::kImag: return llvm::ConstantFP::get(operand_value->getType(), 0.0); default: return Unimplemented("unary floating-point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType component_type = primitive_util::IsComplexType(input_type) @@ -494,12 +484,11 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); + auto sum_sq = FAdd(FMul(a, a), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kLog1p: { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) @@ -507,14 +496,12 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto one = llvm::ConstantFP::get(llvm_ty, 1.0); - auto a_plus_one = b_->CreateFAdd(a, one); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one), - b_->CreateFMul(b, b)); + auto a_plus_one = FAdd(a, one); + auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -528,11 +515,9 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return EmitComposeComplex(op, - b_->CreateFPCast(EmitExtractReal(operand_value), - to_ir_component_type), - b_->CreateFPCast(EmitExtractImag(operand_value), - to_ir_component_type)); + return EmitComposeComplex( + op, FPCast(EmitExtractReal(operand_value), to_ir_component_type), + FPCast(EmitExtractImag(operand_value), to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) @@ -542,8 +527,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); - return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b), - b_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b)); } case HloOpcode::kExpm1: { // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i @@ -554,8 +538,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); - auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one); - auto imag_result = b_->CreateFMul(exp_a, sin_b); + auto real_result = FSub(FMul(exp_a, cos_b), one); + auto imag_result = FMul(exp_a, sin_b); return EmitComposeComplex(op, real_result, imag_result); } case HloOpcode::kCos: { @@ -570,14 +554,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)), - b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b))); + return EmitComposeComplex(op, + FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)), + FMul(sin_a, FSub(half_exp_neg_b, half_exp_b))); } case HloOpcode::kSin: { // sin(z) = .5i(e^(-iz) - e^(iz)) @@ -593,14 +576,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)), - b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b))); + return EmitComposeComplex(op, + FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)), + FMul(cos_a, FSub(half_exp_b, half_exp_neg_b))); } case HloOpcode::kTanh: { /* @@ -628,74 +610,63 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); - auto exp_neg_a = - b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = b_->CreateFSub( - b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = b_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = b_->CreateFMul(sin_b, sin_b); - auto real_num = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a); + auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = + FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = FMul(cos_b, cos_b); + auto sin_b_sq = FMul(sin_b, sin_b); + auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + FMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = FMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a); auto exp_a_plus_exp_neg_a_sq = - b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a); + FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a); auto exp_a_minus_exp_neg_a_sq = - b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = b_->CreateFMul( - cos_b_sin_b, - b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); - auto denom = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom), - b_->CreateFDiv(imag_num, denom)); + FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = FMul( + cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); + auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, FDiv(real_num, denom), + FDiv(imag_num, denom)); } case HloOpcode::kAbs: { - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = FAdd( + FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); } case HloOpcode::kSign: { // Sign(c) = c / |c| - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = FAdd( + FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero); - return b_->CreateSelect( + auto oeq = FCmpOEQ(cplx_abs, zero); + return Select( oeq, EmitComposeComplex(op, zero, zero), - EmitComposeComplex( - op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), - b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs))); + EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs), + FDiv(EmitExtractImag(operand_value), cplx_abs))); } case HloOpcode::kNegate: - return EmitComposeComplex(op, - b_->CreateFNeg(EmitExtractReal(operand_value)), - b_->CreateFNeg(EmitExtractImag(operand_value))); + return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), + FNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: return EmitExtractReal(operand_value); case HloOpcode::kImag: return EmitExtractImag(operand_value); default: return Unimplemented("unary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType operand_type = op->operand(0)->shape().element_type(); if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || operand_type == PRED) { @@ -710,21 +681,20 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( } StatusOr ElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: - return b_->CreateFAdd(lhs_value, rhs_value); + return FAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateFSub(lhs_value, rhs_value); + return FSub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateFMul(lhs_value, rhs_value); + return FMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return b_->CreateFDiv(lhs_value, rhs_value); + return FDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: - return b_->CreateFRem(lhs_value, rhs_value); + return FRem(lhs_value, rhs_value); // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas // unordered comparisons return true. @@ -761,66 +731,52 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value); default: return Unimplemented("binary floating point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kAdd: - return EmitComposeComplex(op, - b_->CreateFAdd(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFAdd(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return EmitComposeComplex(op, - b_->CreateFSub(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFSub(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: return EmitComposeComplex( op, - b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)))); + FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))), + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) auto rhs_sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(rhs_value), - EmitExtractImag(rhs_value))); + FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero); - auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero); - auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero); - return b_->CreateSelect( + auto oeq = FCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero); + return Select( oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), - EmitComposeComplex( - op, - b_->CreateFDiv( - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - rhs_sum_sq), - b_->CreateFDiv( - b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value))), - rhs_sum_sq))); + EmitComposeComplex(op, + FDiv(FAdd(FMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq), + FDiv(FSub(FMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas @@ -830,21 +786,19 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. case HloOpcode::kEq: - return b_->CreateAnd( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kNe: - return b_->CreateOr( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kPower: { // (a+bi)^(c+di) = @@ -856,68 +810,71 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto b = EmitExtractImag(lhs_value); auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); + auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = b_->CreateFMul(one_half, c); + auto half_c = FMul(one_half, c); TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, EmitPow(component_type, aa_p_bb, half_c)); - auto neg_d = b_->CreateFNeg(d); + auto neg_d = FNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); - auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs); + auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, EmitExp(component_type, neg_d_arg_lhs)); - auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = b_->CreateFMul(one_half, d); - auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs), - b_->CreateFMul(half_d, ln_aa_p_bb)); + auto half_d = FMul(one_half, d); + auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q), - b_->CreateFMul(coeff, sin_q)); + return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)); } default: return Unimplemented("binary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_); } llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, - llvm::Value* x) const { - if (prim_type != F32) { - // TODO(b/34339814): Implement inverse erf for F64. + llvm::Value* x) { + if (prim_type != F16 && prim_type != F32 && prim_type != F64) { return Unimplemented( "Inverse erf is only implemented for element " - "type F32."); + "types F16, F32 and F64."); + } + + // Upcast half to float. + if (prim_type == F16) { + x = b_->CreateFPExt(x, b_->getFloatTy()); } - auto getFloat = [&](const float f) { - return llvm::ConstantFP::get(b_->getFloatTy(), f); + + auto get_float = [&](const double f) { + return llvm::ConstantFP::get(x->getType(), f); }; - auto multiply_add = [&](tensorflow::gtl::ArraySlice coefficients, + auto multiply_add = [&](absl::Span coefficients, llvm::Value* w) { - llvm::Value* p = getFloat(coefficients.front()); - coefficients.pop_front(); + llvm::Value* p = get_float(coefficients.front()); + coefficients.remove_prefix(1); for (float coefficient : coefficients) { - p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient)); + p = FAdd(FMul(p, w), get_float(coefficient)); } return p; }; // Approximation for inverse error function from // Giles, M., "Approximating the erfinv function". - // The approximation has the form: - // w = log((1-x)*(1+x)) + // The approximation has the form (float version): + // w = -log((1-x)*(1+x)) // if ( w < 5 ) { // w = w - 2.5 // p = sum_{i=1}^n lq[i]*w^i @@ -927,105 +884,179 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, // } // return p*x llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::log, {b_->getFloatTy()}); + module_, llvm::Intrinsic::log, {x->getType()}); - llvm::Value* w = b_->CreateFNeg(b_->CreateCall( - logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x), - b_->CreateFAdd(getFloat(1.0f), x))})); + llvm::Value* w = FNeg(Call( + logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))})); llvm::Value* p_addr = - llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); + llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_); + + if (prim_type == F16 || prim_type == F32) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_); + // Handle true BB. + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(2.5f)); + absl::Span lq{ + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + llvm::Value* p = multiply_add(lq, lw); + Store(p, p_addr); + } - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); - // Handle true BB. - SetToFirstInsertPoint(if_data.true_block, b_); - { - llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f)); - tensorflow::gtl::ArraySlice lq{ - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - llvm::Value* p = multiply_add(lq, lw); - b_->CreateStore(p, p_addr); - } + // Handle false BB. + SetToFirstInsertPoint(if_data.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f)); + absl::Span gq{ + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + llvm::Value* p = multiply_add(gq, gw); + Store(p, p_addr); + } - // Handle false BB. - SetToFirstInsertPoint(if_data.false_block, b_); - { - llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - - llvm::Value* gw = - b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); - tensorflow::gtl::ArraySlice gq{ - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - llvm::Value* p = multiply_add(gq, gw); - b_->CreateStore(p, p_addr); - } + SetToFirstInsertPoint(if_data.after_block, b_); + } else { + DCHECK(prim_type == F64); + + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_); + + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(3.125)); + absl::Span c{ + -3.6444120640178196996e-21, -1.685059138182016589e-19, + 1.2858480715256400167e-18, 1.115787767802518096e-17, + -1.333171662854620906e-16, 2.0972767875968561637e-17, + 6.6376381343583238325e-15, -4.0545662729752068639e-14, + -8.1519341976054721522e-14, 2.6335093153082322977e-12, + -1.2975133253453532498e-11, -5.4154120542946279317e-11, + 1.051212273321532285e-09, -4.1126339803469836976e-09, + -2.9070369957882005086e-08, 4.2347877827932403518e-07, + -1.3654692000834678645e-06, -1.3882523362786468719e-05, + 0.0001867342080340571352, -0.00074070253416626697512, + -0.0060336708714301490533, 0.24015818242558961693, + 1.6536545626831027356}; + llvm::Value* p = multiply_add(c, lw); + Store(p, p_addr); + } - SetToFirstInsertPoint(if_data.after_block, b_); - llvm::Value* p = b_->CreateLoad(p_addr); - return b_->CreateFMul(p, x); + SetToFirstInsertPoint(if_data.false_block, b_); + llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_); + SetToFirstInsertPoint(if_data_second.true_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25)); + absl::Span t1{ + 2.2137376921775787049e-09, 9.0756561938885390979e-08, + -2.7517406297064545428e-07, 1.8239629214389227755e-08, + 1.5027403968909827627e-06, -4.013867526981545969e-06, + 2.9234449089955446044e-06, 1.2475304481671778723e-05, + -4.7318229009055733981e-05, 6.8284851459573175448e-05, + 2.4031110387097893999e-05, -0.0003550375203628474796, + 0.00095328937973738049703, -0.0016882755560235047313, + 0.0024914420961078508066, -0.0037512085075692412107, + 0.005370914553590063617, 1.0052589676941592334, + 3.0838856104922207635}; + llvm::Value* p = multiply_add(t1, gw); + Store(p, p_addr); + } + + SetToFirstInsertPoint(if_data_second.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0)); + absl::Span t2{ + -2.7109920616438573243e-11, -2.5556418169965252055e-10, + 1.5076572693500548083e-09, -3.7894654401267369937e-09, + 7.6157012080783393804e-09, -1.4960026627149240478e-08, + 2.9147953450901080826e-08, -6.7711997758452339498e-08, + 2.2900482228026654717e-07, -9.9298272942317002539e-07, + 4.5260625972231537039e-06, -1.9681778105531670567e-05, + 7.5995277030017761139e-05, -0.00021503011930044477347, + -0.00013871931833623122026, 1.0103004648645343977, + 4.8499064014085844221}; + llvm::Value* p = multiply_add(t2, gw); + Store(p, p_addr); + } + + SetToFirstInsertPoint(if_data.after_block, b_); + } + llvm::Value* p = Load(p_addr); + x = FMul(p, x); + // Trunc back to half if needed. + if (prim_type == F16) { + x = b_->CreateFPTrunc(x, b_->getHalfTy()); + } + return x; } -StatusOr ElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type, + llvm::Value* value) { // Compute erfcinv(value) by calculating erfinv(1.0 - value). auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); - return EmitErfInv(prim_type, b_->CreateFSub(one, value)); + return EmitErfInv(prim_type, FSub(one, value)); } StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); auto negative_half = llvm::ConstantFP::get(type, -0.5); // When x is large, the naive evaluation of ln(x + 1) is more // accurate than the Taylor series. - TF_ASSIGN_OR_RETURN(auto for_large_x, - EmitLog(prim_type, b_->CreateFAdd(x, one))); + TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one))); // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. - auto for_small_x = - b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x); + auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x); const auto kAntilogarithmIsSmallThreshold = 1e-4; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( + auto x_is_small = FCmpOLT( abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); - return b_->CreateSelect(x_is_small, for_small_x, for_large_x); + return Select(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); @@ -1033,35 +1064,40 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, // When the exponent is large, the naive evaluation of e^(x) - 1 is more // accurate than the Taylor series. TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); - auto for_large_x = b_->CreateFSub(exp_x, one); + auto for_large_x = FSub(exp_x, one); // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. - auto x_squared = b_->CreateFAdd(x, x); - auto x_squared_over_two = b_->CreateFMul(x_squared, half); - auto for_small_x = b_->CreateFAdd(x, x_squared_over_two); + auto x_squared = FAdd(x, x); + auto x_squared_over_two = FMul(x_squared, half); + auto for_small_x = FAdd(x, x_squared_over_two); const auto kExponentIsSmallThreshold = 1e-5; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( - abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); - return b_->CreateSelect(x_is_small, for_small_x, for_large_x); + auto x_is_small = + FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + return Select(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, {lhs->getType()}, b_); } StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return Unimplemented("atan2"); } +StatusOr ElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { + return Unimplemented("tanh"); +} + StatusOr ElementalIrEmitter::EmitReducePrecision( - const HloInstruction* hlo, llvm::Value* x) const { + const HloInstruction* hlo, llvm::Value* x) { if (hlo->operand(0)->shape().element_type() != F32) { return Unimplemented("reduce-precision only implemented for F32"); } @@ -1092,23 +1128,103 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b, return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value); } +llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) { + return llvm::ConstantInt::get(llvm::cast(type), 1); +} + +llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) { + return llvm::ConstantInt::get(llvm::cast(type), 0); +} + +llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) { + auto* integer_type = llvm::cast(type); + return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue( + integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) { + auto* integer_type = llvm::cast(type); + return llvm::ConstantInt::get( + integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) { + return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0)); +} + +llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs, + llvm::Value* rhs) { + return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())), + ICmpEQ(rhs, GetMinusOne(rhs->getType()))); +} + +llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) { + // Integer division overflow behavior: + // + // X / 0 == -1 + // INT_SMIN /s -1 = INT_SMIN + + if (!is_signed) { + llvm::Value* udiv_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = UDiv(lhs, safe_rhs); + return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = SDiv(lhs, safe_rhs); + + return Select( + has_zero_divisor, GetMinusOne(lhs->getType()), + Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div)); +} + +llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) { + // Integer remainder overflow behavior: + // + // X % 0 == X + // INT_SMIN %s -1 = 0 + + if (!is_signed) { + llvm::Value* urem_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = URem(lhs, safe_rhs); + return Select(urem_is_unsafe, lhs, safe_rem); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = SRem(lhs, safe_rhs); + + return Select( + has_zero_divisor, lhs, + Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem)); +} + StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { + bool is_signed) { switch (op->opcode()) { // TODO(jingyue): add the "nsw" attribute for signed types. case HloOpcode::kAdd: - return b_->CreateAdd(lhs_value, rhs_value); + return Add(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateSub(lhs_value, rhs_value); + return Sub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateMul(lhs_value, rhs_value); + return Mul(lhs_value, rhs_value); case HloOpcode::kDivide: - return is_signed ? b_->CreateSDiv(lhs_value, rhs_value) - : b_->CreateUDiv(lhs_value, rhs_value); + return EmitIntegerDivide(lhs_value, rhs_value, is_signed); case HloOpcode::kRemainder: - return is_signed ? b_->CreateSRem(lhs_value, rhs_value) - : b_->CreateURem(lhs_value, rhs_value); + return EmitIntegerRemainder(lhs_value, rhs_value, is_signed); case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, rhs_value, b_); @@ -1136,11 +1252,11 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( case HloOpcode::kMaximum: return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: - return b_->CreateAnd(lhs_value, rhs_value); + return And(lhs_value, rhs_value); case HloOpcode::kOr: - return b_->CreateOr(lhs_value, rhs_value); + return Or(lhs_value, rhs_value); case HloOpcode::kXor: - return b_->CreateXor(lhs_value, rhs_value); + return Xor(lhs_value, rhs_value); // Shifting out bits >= the number of bits in the type being shifted // produces a poison value in LLVM which is basically "deferred undefined @@ -1149,43 +1265,43 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( // UB. case HloOpcode::kShiftRightArithmetic: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateAShr(lhs_value, rhs_value), + AShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/true); case HloOpcode::kShiftLeft: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateShl(lhs_value, rhs_value), + Shl(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateLShr(lhs_value, rhs_value), + LShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { - return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE - : llvm::ICmpInst::ICMP_UGE, - lhs_value, rhs_value), - lhs_value, rhs_value); + bool is_signed) { + return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { - return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE - : llvm::ICmpInst::ICMP_ULE, - lhs_value, rhs_value), - lhs_value, rhs_value); + bool is_signed) { + return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const { + int64 operand_no) { CHECK(hlo.IsElementwise()) << "HLO " << hlo.ToString() << " is not elementwise."; @@ -1226,7 +1342,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( StatusOr ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const { + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) { TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean, operand_to_generator.at(hlo->operand(0))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma, @@ -1239,22 +1355,30 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( // Convert raw integer to float in range [0, 1) if the element is a float. llvm::Value* elem_value = raw_value; if (elem_ir_ty->isFloatingPointTy()) { - elem_value = b_->CreateUIToFP(elem_value, elem_ir_ty); unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); - elem_value = b_->CreateFDiv( - elem_value, - llvm::ConstantFP::get(elem_ir_ty, - raw_value_size_in_bits == 64 ? 0x1p64 : 0x1p32)); + // Perform the division using the float type with the same number of bits + // as the raw value to avoid overflow. + if (raw_value_size_in_bits == 32) { + elem_value = UIToFP(elem_value, b_->getFloatTy()); + elem_value = FDiv(elem_value, + llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); + } else { + elem_value = UIToFP(elem_value, b_->getDoubleTy()); + elem_value = FDiv( + elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); + } + + if (elem_ir_ty != elem_value->getType()) { + elem_value = FPTrunc(elem_value, elem_ir_ty); + } } // Convert the value for the requested distribution. switch (hlo->random_distribution()) { case RNG_UNIFORM: { if (elem_ir_ty->isFloatingPointTy()) { - return b_->CreateFAdd( - b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value), - a_or_mean); + return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean); } else { // To generate a uniform random value in [a, b) from a raw random sample // in range [0, 2^N), we let range = b - a and return @@ -1267,22 +1391,21 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( // the same cost as if the whole warp were to re-sample. So an // efficient re-sampling implementation on GPU would need to do // nontrivial work to share entropy between threads in the warp. - auto range = b_->CreateSub(b_or_sigma, a_or_mean); - return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range)); + auto range = Sub(b_or_sigma, a_or_mean); + return Add(a_or_mean, URem(elem_value, range)); } } case RNG_NORMAL: { TF_ASSIGN_OR_RETURN( llvm::Value * r, - EmitErfcInv(elem_prim_ty, - b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), - elem_value))); - return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean); + EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), + elem_value))); + return FAdd(FMul(r, b_or_sigma), a_or_mean); } default: return InvalidArgument( "unhandled distribution %s", - RandomDistribution_Name(hlo->random_distribution()).c_str()); + RandomDistribution_Name(hlo->random_distribution())); } } @@ -1397,8 +1520,7 @@ std::array CalculateSampleValues( // Precondition: the RNG instruction is not fused. llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { VLOG(3) << "Using philox RNG algorithm"; CHECK(!hlo->IsFused()); // A random number generated by the per module random number generator. @@ -1421,7 +1543,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Load the global state variable for the Philox RNG algorithm. llvm::GlobalVariable* rng_state_ptr = llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_); - llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value"); + llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value"); // Build and return the elemental IR generator to generate a random value for // the element corresponding to the current thread. @@ -1447,8 +1569,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // element within the sample. llvm::Value* elems_per_sample_value = llvm::ConstantInt::get(index_ty, elems_per_sample); - llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value); - llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value); + llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value); + llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value); std::array counter_values = CalculateSampleValues( sample_idx, hlo_random_value, global_random_number, rng_state, b_); @@ -1456,18 +1578,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Store the four counter_values into the sample_address alloca so we can // load the elem_offset'th one below. for (int idx = 0; idx < 4; ++idx) { - b_->CreateStore(counter_values[idx], - b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx))); + Store(counter_values[idx], + InBoundsGEP(sample_address, b_->getInt32(idx))); } llvm::Type* int64_ty = b_->getInt64Ty(); CHECK(elems_per_sample == 2 || elems_per_sample == 4); llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty; // Retrieve the raw value for the current element from the current sample. - llvm::Value* raw_elem_value = b_->CreateLoad( - b_->CreateInBoundsGEP( - b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()), - elem_offset), + llvm::Value* raw_elem_value = Load( + InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()), + elem_offset), "raw_elem_value"); return ConvertValueForDistribution(hlo, operand_to_generator, index, @@ -1478,7 +1599,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( StatusOr ElementalIrEmitter::EmitElementalSelect( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1488,14 +1609,14 @@ StatusOr ElementalIrEmitter::EmitElementalSelect( TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return b_->CreateSelect(b_->CreateTrunc(pred_value, b_->getInt1Ty()), - on_true_value, on_false_value); + return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value, + on_false_value); } StatusOr ElementalIrEmitter::EmitElementalClamp( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * min_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1514,14 +1635,14 @@ StatusOr ElementalIrEmitter::EmitElementalClamp( max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed); } else { return Unimplemented("Clamp unimplemented for %s", - PrimitiveType_Name(prim_type).c_str()); + PrimitiveType_Name(prim_type)); } } StatusOr ElementalIrEmitter::EmitElementalConcatenate( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const { + const llvm_ir::IrArray::Index& target_index) { const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; @@ -1543,9 +1664,9 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( } llvm_ir::SetToFirstInsertPoint(exit_block, b_); - llvm::PHINode* output = b_->CreatePHI( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - hlo->operands().size()); + llvm::PHINode* output = + PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + hlo->operands().size()); auto prior_insert_point = b_->GetInsertPoint(); b_->SetInsertPoint(init_block); @@ -1560,9 +1681,8 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - b_->CreateCondBr( - b_->CreateICmpULT(source_index[concat_dim], concat_dim_size), - true_block, false_block); + CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block, + false_block); // Create the terminator of the true block before calling operand // generators, because they require non-degenerate basic blocks. @@ -1575,11 +1695,10 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( // Subtract the size of the concat dimension of the current operand // from the source index. b_->SetInsertPoint(false_block); - source_index[concat_dim] = - b_->CreateSub(source_index[concat_dim], concat_dim_size); + source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size); } - b_->CreateUnreachable(); + Unreachable(); b_->SetInsertPoint(exit_block, prior_insert_point); return output; } @@ -1587,7 +1706,7 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); const int64 rank = ShapeUtil::Rank(input_hlo->shape()); @@ -1604,7 +1723,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); int64 largest_valid_start_index = input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i); CHECK_GE(largest_valid_start_index, 0); @@ -1624,7 +1743,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index - input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]); + input_index[i] = Add(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1632,7 +1751,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( StatusOr ElementalIrEmitter::EmitElementalGather( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { const Shape& operand_shape = hlo->operand(0)->shape(); const Shape& indices_shape = hlo->operand(1)->shape(); const Shape& output_shape = hlo->shape(); @@ -1655,22 +1774,21 @@ StatusOr ElementalIrEmitter::EmitElementalGather( std::vector operand_to_output_dim(operand_shape.dimensions_size(), -1); for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { - if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { operand_index.push_back(index.GetConstantWithIndexType(0)); } else { - int64 output_window_dim = - dim_numbers.output_window_dims(operand_index_dim++); + int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++); operand_to_output_dim[i] = output_window_dim; operand_index.push_back(index[output_window_dim]); } } - // This is the index of the index vector in the gather_indices tensor. + // This is the index of the index vector in the start_indices tensor. IrArray::Index gather_index_index(index_type); { std::vector gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { gather_index_index.push_back(index[i]); } } @@ -1682,8 +1800,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { llvm::Value* gather_dim_component_extended = - b_->CreateSExtOrTrunc(index_component, index_type); - int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim); + SExtOrTrunc(index_component, index_type); + int64 operand_dim = dim_numbers.start_index_map(dim); int64 output_dim = operand_to_output_dim[operand_dim]; // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. // This means we set the iteration index to 0, so for the purpose of the @@ -1706,8 +1824,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( gather_dim_component_extended, is_signed), is_signed); - operand_index[operand_dim] = b_->CreateAdd( - operand_index[operand_dim], gather_dim_component_extended_inbound); + operand_index[operand_dim] = + Add(operand_index[operand_dim], gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { @@ -1731,7 +1849,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { const HloInstruction* input_hlo = hlo->operand(0); const HloInstruction* update_hlo = hlo->operand(1); const HloInstruction* start_hlo = hlo->operand(2); @@ -1754,7 +1872,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); llvm::Value* update_dim_size = index_typed_const(update_hlo->shape().dimensions(i)); int64 largest_valid_start_index = @@ -1770,14 +1888,14 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; - slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size); - - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection"); - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection"); + slice_limit_index[i] = Add(slice_start_index[i], update_dim_size); + + slice_intersection = + And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]), + "slice_intersection"); + slice_intersection = + And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection"); } // Emit: @@ -1794,26 +1912,26 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Compute update index for intersection case. llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - update_index[i] = b_->CreateSub(index[i], slice_start_index[i]); + update_index[i] = Sub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); - b_->CreateStore(true_value, ret_value_addr); + Store(true_value, ret_value_addr); // Handle false BB (return data from 'input') SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * false_value, operand_to_generator.at(input_hlo)(index)); - b_->CreateStore(false_value, ret_value_addr); + Store(false_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); - return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalPad( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const { + const llvm_ir::IrArray::Index& padded_index) { auto index = padded_index; llvm::Value* in_bounds = b_->getTrue(); for (size_t i = 0; i < index.size(); ++i) { @@ -1821,26 +1939,22 @@ StatusOr ElementalIrEmitter::EmitElementalPad( return llvm::ConstantInt::get(index[i]->getType(), n); }; const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = - b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = b_->CreateAnd(in_bounds, - b_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = b_->CreateAnd( + index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = + And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds"); + in_bounds = And( in_bounds, - b_->CreateICmpEQ( + ICmpEQ( index_typed_const(0), - b_->CreateURem(index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = b_->CreateSDiv( - index[i], index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), + URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))), "in_bounds"); + index[i] = + SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = + And(in_bounds, + ICmpSLT(index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); } // if (in_bounds) { @@ -1856,26 +1970,26 @@ StatusOr ElementalIrEmitter::EmitElementalPad( SetToFirstInsertPoint(if_data.true_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); - b_->CreateStore(operand_value, ret_value_addr); + Store(operand_value, ret_value_addr); SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(padding_value, ret_value_addr); + Store(padding_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); // Don't create phi(operand_value, padding_value) here, because invoking // operand_to_generator may create new basic blocks, making the parent // of operand_value or padding_value no longer a predecessor of // if_data.after_block. - return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalDot( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const { + const llvm_ir::IrArray::Index& dot_result_index) { auto lhs_generator = operand_to_generator.at(hlo->operand(0)); auto rhs_generator = operand_to_generator.at(hlo->operand(1)); @@ -1903,8 +2017,7 @@ StatusOr ElementalIrEmitter::EmitElementalDot( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_); - b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); + Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_); @@ -1926,42 +2039,37 @@ StatusOr ElementalIrEmitter::EmitElementalDot( } rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); - llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca); + llvm::Value* current_accumulator = Load(accumulator_alloca); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = b_->CreateFSub( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); - llvm::Value* product_imag = b_->CreateFAdd( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); - next_accumulator = b_->CreateInsertValue( + llvm::Value* product_real = + FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); + llvm::Value* product_imag = + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); + next_accumulator = InsertValue( current_accumulator, - b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real), - {0}); - next_accumulator = b_->CreateInsertValue( + FAdd(EmitExtractReal(current_accumulator), product_real), {0}); + next_accumulator = InsertValue( next_accumulator, - b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag), - {1}); + FAdd(EmitExtractImag(current_accumulator), product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = b_->CreateFAdd(current_accumulator, - b_->CreateFMul(lhs_value, rhs_value)); + next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value)); } else { - next_accumulator = - b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value)); + next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value)); } - b_->CreateStore(next_accumulator, accumulator_alloca); + Store(next_accumulator, accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_); - return b_->CreateLoad(accumulator_alloca); + return Load(accumulator_alloca); } llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -2055,10 +2163,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); auto source_index = target_index; for (int64 dim : hlo->dimensions()) { - source_index[dim] = b_->CreateSub( - llvm::ConstantInt::get(target_index[dim]->getType(), - hlo->shape().dimensions(dim) - 1), - target_index[dim]); + source_index[dim] = + Sub(llvm::ConstantInt::get(target_index[dim]->getType(), + hlo->shape().dimensions(dim) - 1), + target_index[dim]); } return operand_to_generator.at(operand)(source_index); }; @@ -2072,6 +2180,61 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(), hlo->dimensions(), b_)); }; + case HloOpcode::kIota: + return [this, hlo]( + const IrArray::Index& target_index) -> StatusOr { + auto* iota = Cast(hlo); + PrimitiveType element_type = iota->shape().element_type(); + IrArray::Index elem_index = + ShapeUtil::Rank(iota->shape()) > 1 + ? target_index.SourceIndexOfBroadcast( + iota->shape(), + ShapeUtil::MakeShapeWithDescendingLayout( + element_type, + {iota->shape().dimensions(iota->iota_dimension())}), + {iota->iota_dimension()}, b_) + : target_index; + llvm::Value* elem_index_linear = elem_index.linear(); + if (elem_index_linear == nullptr) { + std::vector iota_bound = { + iota->shape().dimensions(iota->iota_dimension())}; + elem_index_linear = elem_index.Linearize(iota_bound, b_); + } + Shape component_shape = + ShapeUtil::ElementIsComplex(iota->shape()) + ? ShapeUtil::ComplexComponentShape(iota->shape()) + : iota->shape(); + PrimitiveType component_element_type = component_shape.element_type(); + llvm::Value* iota_result; + if (ShapeUtil::ElementIsIntegral(component_shape)) { + iota_result = b_->CreateIntCast( + elem_index_linear, + llvm_ir::PrimitiveTypeToIrType(component_element_type, module_), + /*isSigned=*/false); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape)) + << component_element_type; + llvm::Type* float_ir_type; + if (component_element_type == BF16) { + float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_); + } else { + float_ir_type = + llvm_ir::PrimitiveTypeToIrType(component_element_type, module_); + } + llvm::Value* float_val = + b_->CreateUIToFP(elem_index_linear, float_ir_type); + if (component_element_type == BF16) { + iota_result = EmitF32ToBF16(float_val, b_); + } else { + iota_result = float_val; + } + } + if (ShapeUtil::ElementIsComplex(iota->shape())) { + return EmitComposeComplex(iota, iota_result, nullptr); + } else { + return iota_result; + } + }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { @@ -2137,28 +2300,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); }; } } -llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { - return b_->CreateExtractValue(value, {0}); +llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) { + return ExtractValue(value, {0}); } -llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { - return b_->CreateExtractValue(value, {1}); +llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) { + return ExtractValue(value, {1}); } llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const { + llvm::Value* imag) { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto complex = b_->CreateInsertValue( - llvm::ConstantAggregateZero::get(cplx_type), real, {0}); + auto complex = + InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0}); if (imag != nullptr) { - complex = b_->CreateInsertValue(complex, imag, {1}); + complex = InsertValue(complex, imag, {1}); } return complex; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index fcb34557a52d35ef30a5dee643171e17407d05c2..d3e2acaabd4f602171def70ccd3d4fd5adce0d0d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -23,12 +23,13 @@ limitations under the License. #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { -class ElementalIrEmitter { +class ElementalIrEmitter : public IrBuilderMixin { public: using HloToElementGeneratorMap = std::unordered_map; @@ -40,97 +41,114 @@ class ElementalIrEmitter { virtual ~ElementalIrEmitter() = default; virtual StatusOr EmitUnaryOp(const HloInstruction* op, - llvm::Value* operand_value) const; + llvm::Value* operand_value); virtual StatusOr EmitBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Returns a function to generate an element of the output of `hlo`, given a // map of functions to generate elements of its operands. virtual llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); - llvm::IRBuilder<>* b() const { return b_; } - llvm::Module* module() const { return module_; } + llvm::IRBuilder<>* b() { return b_; } + + // builder() is for IrBuilderMixin. + llvm::IRBuilder<>* builder() { return b_; } + + llvm::Module* module() { return module_; } protected: - virtual StatusOr EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitIntegerUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); + + virtual StatusOr EmitFloatUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitComplexUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + llvm::Value* IsZero(llvm::Value* v); + llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* GetZero(llvm::Type* type); + llvm::Value* GetOne(llvm::Type* type); + llvm::Value* GetIntSMin(llvm::Type* type); + llvm::Value* GetMinusOne(llvm::Type* type); + + llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed); + llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed); virtual StatusOr EmitIntegerBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); - virtual StatusOr EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); - virtual StatusOr EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr EmitComplexBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); virtual StatusOr EmitErfInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitAtan2(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); virtual StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitExp(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitPow(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); + + virtual StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value); virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, - llvm::Value* x) const; + llvm::Value* x); - virtual llvm::Value* EmitExtractReal(llvm::Value* value) const; - virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; + virtual llvm::Value* EmitExtractReal(llvm::Value* value); + virtual llvm::Value* EmitExtractImag(llvm::Value* value); // Composes a complex struct. imag may be nullptr for simple cast operations. llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const; + llvm::Value* imag); // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its @@ -139,50 +157,50 @@ class ElementalIrEmitter { // Precondition: `hlo` is an elementwise op. llvm_ir::IrArray::Index ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const; + int64 operand_no); // Identifier of the thread unique among all threads on the device - virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); } + virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } StatusOr EmitElementalSelect( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalClamp( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalConcatenate( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const; + const llvm_ir::IrArray::Index& target_index); StatusOr EmitElementalDynamicSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalGather( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalPad( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const; + const llvm_ir::IrArray::Index& padded_index); StatusOr EmitElementalDot( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const; + const llvm_ir::IrArray::Index& dot_result_index); llvm::IRBuilder<>* const b_; @@ -197,13 +215,13 @@ class ElementalIrEmitter { // random number generation algorithm. llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. StatusOr ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const; + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index addb016b0481b744ff42ba827104099b6cdc3bb9..852f34e06df35242b13110ae4411b8c969c26019 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -24,12 +24,11 @@ limitations under the License. namespace xla { namespace { -using tensorflow::gtl::nullopt; +using absl::nullopt; class ElementalIrEmitterExecutionTest : public HloTestBase { protected: - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -57,9 +56,9 @@ ENTRY main { } )"; - std::unique_ptr lhs = LiteralUtil::CreateR3({{{1}, {2}}}); - std::unique_ptr rhs = LiteralUtil::CreateR3({{{3}, {4}}}); - RunTest(hlo_text, {lhs.get(), rhs.get()}); + Literal lhs = LiteralUtil::CreateR3({{{1}, {2}}}); + Literal rhs = LiteralUtil::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {&lhs, &rhs}); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index fd75847d0c0e737957401b8efc420d504a3c0706..47c56e2f7fbd9f53be6a2b189c5c36cf4fdcdccb 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" @@ -22,16 +24,14 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" -using tensorflow::gtl::ArraySlice; namespace xla { StatusOr> Executable::ExecuteOnStreams( - ArraySlice run_options, - ArraySlice> arguments) { + absl::Span run_options, + absl::Span> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); std::vector return_values; @@ -62,7 +62,7 @@ StatusOr> Executable::ExecuteOnStreams( StatusOr Executable::ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - ArraySlice arguments) { + absl::Span arguments) { se::Stream* stream = run_options->stream(); std::unique_ptr timer; if (profile != nullptr) { @@ -76,8 +76,8 @@ StatusOr Executable::ExecuteOnStreamWrapper( std::unique_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? MakeUnique(&hlo_profile_printer_data(), - &hlo_profile_index_map()) + ? absl::make_unique(&hlo_profile_printer_data(), + &hlo_profile_index_map()) : nullptr; StatusOr return_value = @@ -154,9 +154,9 @@ Status Executable::DumpHloSnapshot() { const string& directory_path = module_config().debug_options().xla_dump_executions_to(); const auto& module = hlo_snapshot_->hlo().hlo_module(); - string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", module.id(), - module.entry_computation_name().c_str(), ++execution_count_); + string filename = + absl::StrFormat("computation_%d__%s__execution_%d", module.id(), + module.entry_computation_name(), ++execution_count_); return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 98eaeee30a693211ae564a5ef3c373f0364bef11..3a6780f2a67f230cae626ea00cfbf93b4e60d968 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -18,7 +18,10 @@ limitations under the License. #include #include +#include +#include "absl/types/span.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -26,18 +29,33 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" namespace xla { +// ExecutionOutput encapsulates the output buffers of a execution and the +// leftover buffers to be released by the caller. +struct ExecutionOutput { + ExecutionOutput(ScopedShapedBuffer result, + std::vector to_be_released) + : result(std::move(result)), to_be_released(std::move(to_be_released)) {} + ScopedShapedBuffer result; + + // Leftover buffers for the caller to release. Elements in this list are + // donated input memory buffers that are not reused by XLA as outputs. + std::vector to_be_released; +}; + // A given platform's compiler will produce an Executable -- this is a uniform // interface that is used for launching compiled programs across platforms. class Executable { @@ -63,25 +81,46 @@ class Executable { // Returns a shaped buffer containing the result of the computation. virtual StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) = 0; // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. virtual StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) = 0; + absl::Span arguments) = 0; + + // Starts the given program executing on the given stream/executor. + // + // `arguments` are ShapeTree containing the input parameters. For each element + // in the shape tree, if the element holds the ownership of the memory, it is + // considered donated and XLA will potentially reuse it as output buffers. For + // all donated inputs, XLA is also responsible for freeing them. + // + // If an input is donated to XLA but is not reused as output, it is returned + // as an leftover buffer for the caller to release. + virtual StatusOr ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector> arguments, + HloExecutionProfile* hlo_execution_profile) { + return Unimplemented( + "MaybeOwningDeviceMemory version of overload is not implemented "); + } + + virtual StatusOr ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector> arguments) { + return Unimplemented( + "MaybeOwningDeviceMemory version of overload is not implemented "); + } // Same as ExecuteOnStream(), but runs this executable on multiple // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. virtual StatusOr> ExecuteOnStreams( - tensorflow::gtl::ArraySlice - run_options, - tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> - arguments); + absl::Span run_options, + absl::Span> arguments); // Populates `hlo_execution_profile` from `executor`. This is implicit in any // Execute* API call that takes a hlo_execution_profile argument, but must be @@ -97,7 +136,7 @@ class Executable { // given ExecutionProfile if non-null. StatusOr ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); // Returns the ExecutionProfile from executing on the device. This includes // the number of cycles taken for the computation or the compilation time. diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 228c3fac95c3114484637bd93ec51c60b44403cc..997db7c058af6da8ecff399769b85b803e2e5785 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -53,8 +53,8 @@ ExecutionHandle ExecutionTracker::Register(Backend* backend, tensorflow::mutex_lock lock(execution_mutex_); int64 handle = next_handle_++; auto inserted = handle_to_execution_.emplace( - handle, - MakeUnique(backend, std::move(streams), profile, result)); + handle, absl::make_unique(backend, std::move(streams), + profile, result)); CHECK(inserted.second); ExecutionHandle execution_handle; @@ -66,7 +66,7 @@ Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } handle_to_execution_.erase(handle.handle()); @@ -78,7 +78,7 @@ StatusOr ExecutionTracker::Resolve( tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } return it->second.get(); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h index d3efab3614912e4b0c2c8aa3b80277c326382ed0..986970f8862472d1db7564254a9c1277750bb6eb 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.h +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -26,9 +26,9 @@ namespace xla { // Flattening associates each call site with a unique computation (for // sequential calling contexts) This simplifies buffer assignment and // points-to analysis (see b/36865746 for details). -class FlattenCallGraph : public HloPassInterface { +class FlattenCallGraph : public HloModulePass { public: - tensorflow::StringPiece name() const override { return "flatten-call-graph"; } + absl::string_view name() const override { return "flatten-call-graph"; } // Duplicates computations called from multiple call- or while-nodes to // flatten the call graph. diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 8f6608241ed02bbb7e9fde9b6d767c002435e777..5fbd73a5363b4cdbcaafedbe6f4e7bd6bb2a92d8 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloTestBase { +class FlattenCallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); + std::unique_ptr flat_call_graph = CallGraph::Build(module); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // The true and false computations must now be different. EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index e3a42d0d06be9e4c9ef96ed2e6ff5daa8eebaf3e..cb86c9857936f21d9d2ac6bc22c725b89cca6482 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -24,88 +25,87 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::ArraySlice; static StatusOr TransposeIndexVectorDimToLast( - HloInstruction* gather_indices, int64 index_vector_dim) { - const Shape& gather_indices_shape = gather_indices->shape(); + HloInstruction* start_indices, int64 index_vector_dim) { + const Shape& start_indices_shape = start_indices->shape(); - if (gather_indices_shape.dimensions_size() == index_vector_dim) { - return gather_indices; + if (start_indices_shape.dimensions_size() == index_vector_dim) { + return start_indices; } - if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) { - return gather_indices; + if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) { + return start_indices; } std::vector permutation; - permutation.reserve(gather_indices_shape.dimensions_size()); - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + permutation.reserve(start_indices_shape.dimensions_size()); + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != index_vector_dim) { permutation.push_back(i); } } permutation.push_back(index_vector_dim); - return MakeTransposeHlo(gather_indices, permutation); + return MakeTransposeHlo(start_indices, permutation); } -// Canonicalizes the gather_indices tensors so that we only have deal with some +// Canonicalizes the start_indices tensors so that we only have deal with some // specific cases in the while loop that does the heavy lifting. // // See the "High Level Algorithm" section for a broader picture. static StatusOr CanonicalizeGatherIndices( - HloInstruction* gather_indices, int64 index_vector_dim) { + HloInstruction* start_indices, int64 index_vector_dim) { // Transpose the non-index-vector dimensions to the front. TF_ASSIGN_OR_RETURN( - HloInstruction * transposed_gather_indices, - TransposeIndexVectorDimToLast(gather_indices, index_vector_dim)); + HloInstruction * transposed_start_indices, + TransposeIndexVectorDimToLast(start_indices, index_vector_dim)); bool indices_are_scalar = - index_vector_dim == gather_indices->shape().dimensions_size(); + index_vector_dim == start_indices->shape().dimensions_size(); - // The number of dimensions in gather_indices that are index dimensions. - const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1; + // The number of dimensions in start_indices that are index dimensions. + const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1; - // If there is only one index (i.e. gather_indices has rank 1 and this gather + // If there is only one index (i.e. start_indices has rank 1 and this gather // is really just a dynamic slice) add a leading degenerate dimension for // uniformity. Otherwise create a "collapsed" leading dimension that subsumes // all of the non-index-vector dimensions. - const Shape& shape = transposed_gather_indices->shape(); - if (shape.dimensions_size() == index_dims_in_gather_indices) { - return PrependDegenerateDims(transposed_gather_indices, 1); + const Shape& shape = transposed_start_indices->shape(); + if (shape.dimensions_size() == index_dims_in_start_indices) { + return PrependDegenerateDims(transposed_start_indices, 1); } else { - // Collapse all but the dimensions (0 or 1) in gather_indices containing the + // Collapse all but the dimensions (0 or 1) in start_indices containing the // index vectors. return CollapseFirstNDims( - transposed_gather_indices, - shape.dimensions_size() - index_dims_in_gather_indices); + transposed_start_indices, + shape.dimensions_size() - index_dims_in_start_indices); } } // Expands out or contracts away the gather dimensions in the accumulator // produced by the while loop. -static StatusOr AdjustGatherDimsInAccumulator( - const Shape& gather_indices_shape, HloInstruction* accumulator, +static StatusOr AdjustBatchDimsInAccumulator( + const Shape& start_indices_shape, HloInstruction* accumulator, int64 index_vector_dim) { - std::vector output_gather_dim_bounds; - output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size()); - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + std::vector batch_dim_bounds; + batch_dim_bounds.reserve(start_indices_shape.dimensions_size()); + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != index_vector_dim) { - output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i)); + batch_dim_bounds.push_back(start_indices_shape.dimensions(i)); } } - if (output_gather_dim_bounds.empty()) { - // If output_gather_dim_bounds is empty we must be lowering a (effectively) + if (batch_dim_bounds.empty()) { + // If batch_dim_bounds is empty we must be lowering a (effectively) // dynamic-slice. In that case, there is a leading degenerate gather // dimension that we added to make this special case play well with the // general while loop which we need to remove now. return ElideDegenerateDims(accumulator, {0}); } - return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); + return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds); } -// Expand an index vector from the gather_indices tensor into a vector that can +// Expand an index vector from the start_indices tensor into a vector that can // be used to dynamic-slice out of the gather operand. static StatusOr ExpandIndexVectorIntoOperandSpace( HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, @@ -121,10 +121,8 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( std::vector expanded_index_components; for (int i = 0; i < operand_rank; i++) { - int64 index_vector_dim_index = - FindIndex(dim_numbers.gather_dims_to_operand_dims(), i); - if (index_vector_dim_index != - dim_numbers.gather_dims_to_operand_dims_size()) { + int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i); + if (index_vector_dim_index != dim_numbers.start_index_map_size()) { TF_ASSIGN_OR_RETURN( HloInstruction * component_to_concat, MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, @@ -147,10 +145,10 @@ static StatusOr> GatherLoopBody( const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers(); CHECK_EQ(incoming_loop_state.size(), 3); HloInstruction* const operand = incoming_loop_state[0]; - HloInstruction* const gather_indices = incoming_loop_state[1]; + HloInstruction* const start_indices = incoming_loop_state[1]; HloInstruction* const output_accumulator = incoming_loop_state[2]; - bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1; + bool has_scalar_indices = start_indices->shape().dimensions_size() == 1; CHECK_EQ(has_scalar_indices, dim_numbers.index_vector_dim() == gather.operand(1)->shape().dimensions_size()); @@ -163,24 +161,24 @@ static StatusOr> GatherLoopBody( HloInstruction* index_vector; if (has_scalar_indices) { - // In this case gather_indices has rank 1 and induction_var_as_vector (of + // In this case start_indices has rank 1 and induction_var_as_vector (of // shape {1}) is an index into this rank 1 tensor. TF_ASSIGN_OR_RETURN( index_vector, - MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1})); + MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1})); } else { - // In this case gather_indices has rank 2 and induction_var_as_vector (of + // In this case start_indices has rank 2 and induction_var_as_vector (of // shape {1}) is an index into just the first dimension of this rank 2 // tensor. TF_ASSIGN_OR_RETURN( - HloInstruction * index_into_gather_indices, + HloInstruction * index_into_start_indices, PadVectorWithZeros(induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); - int64 index_vector_size = gather_indices->shape().dimensions(1); + int64 index_vector_size = start_indices->shape().dimensions(1); TF_ASSIGN_OR_RETURN( HloInstruction * index_vector_2d, - MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, + MakeDynamicSliceHlo(start_indices, index_into_start_indices, {1, index_vector_size})); TF_ASSIGN_OR_RETURN(index_vector, @@ -194,26 +192,26 @@ static StatusOr> GatherLoopBody( TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, MakeDynamicSliceHlo(operand, gathered_slice_start, - gather.gather_window_bounds())); + gather.gather_slice_sizes())); TF_ASSIGN_OR_RETURN( - HloInstruction * gathered_slice_with_dims_elided, + HloInstruction* const gathered_slice_with_dims_collapsed, ElideDegenerateDims(gathered_slice, - AsInt64Slice(dim_numbers.elided_window_dims()))); + AsInt64Slice(dim_numbers.collapsed_slice_dims()))); TF_ASSIGN_OR_RETURN( - HloInstruction * gathered_slice_for_update, - PrependDegenerateDims(gathered_slice_with_dims_elided, 1)); + HloInstruction* const gathered_slice_for_update, + PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1)); TF_ASSIGN_OR_RETURN( - HloInstruction * index_vector_into_accumulator, + HloInstruction* const index_vector_into_accumulator, PadVectorWithZeros( induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/ - gathered_slice_with_dims_elided->shape().dimensions_size())); + gathered_slice_with_dims_collapsed->shape().dimensions_size())); TF_ASSIGN_OR_RETURN( - HloInstruction * updated_accumulator, + HloInstruction* const updated_accumulator, MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, index_vector_into_accumulator)); @@ -221,19 +219,19 @@ static StatusOr> GatherLoopBody( // WhileUtil::MakeCountedLoop functions takes care of the induction variable // and the while loop exit condition. return StatusOr>{ - {operand, gather_indices, updated_accumulator}}; + {operand, start_indices, updated_accumulator}}; } static StatusOr CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, - ArraySlice window_bounds, int64 gather_loop_trip_count, + absl::Span slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { std::vector accumulator_state_shape_dims; - accumulator_state_shape_dims.reserve(1 + window_bounds.size()); + accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); - for (int64 i = 0; i < window_bounds.size(); i++) { - if (!c_binary_search(dim_numbers.elided_window_dims(), i)) { - accumulator_state_shape_dims.push_back(window_bounds[i]); + for (int64 i = 0; i < slice_sizes.size(); i++) { + if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + accumulator_state_shape_dims.push_back(slice_sizes[i]); } } return BroadcastZeros(computation, element_type, @@ -241,23 +239,23 @@ static StatusOr CreateGatherLoopAccumulatorInitValue( } // `accumulator` is almost the tensor the gather operation would have produced, -// except that it has the dimensions in the wrong order -- the gather dimensions -// are the major dimensions and the window dimensions are the minor dimensions. +// except that it has the dimensions in the wrong order -- the batch dimensions +// are the major dimensions and the offset dimensions are the minor dimensions. // Fix this up with a transpose. -static StatusOr PermuteGatherAndWindowDims( - HloInstruction* accumulator, ArraySlice output_window_dims, +static StatusOr PermuteBatchAndOffsetDims( + HloInstruction* accumulator, absl::Span offset_dims, int64 output_rank) { std::vector permutation; permutation.reserve(output_rank); - int64 gather_idx_counter = 0; - int64 window_idx_counter = output_rank - output_window_dims.size(); + int64 batch_idx_counter = 0; + int64 offset_idx_counter = output_rank - offset_dims.size(); for (int64 i = 0; i < output_rank; i++) { - bool is_window_dim = c_binary_search(output_window_dims, i); - if (is_window_dim) { - permutation.push_back(window_idx_counter++); + bool is_offset_dim = absl::c_binary_search(offset_dims, i); + if (is_offset_dim) { + permutation.push_back(offset_idx_counter++); } else { - permutation.push_back(gather_idx_counter++); + permutation.push_back(batch_idx_counter++); } } @@ -268,11 +266,11 @@ static StatusOr PermuteGatherAndWindowDims( // // We follow the following steps in sequence: // -// 1. We canonicalize the gather_indices tensor such that it has rank +// 1. We canonicalize the start_indices tensor such that it has rank // 2 (i.e. is a matrix) where each row is an index vector into the // operand. // 2. We iterate over the set of indices in the canonicalized -// gather_indices tensor using a while loop, accumulating slices +// start_indices tensor using a while loop, accumulating slices // of the operand tensor into an accumulator using // DynamicUpdateSlice. // 3. The accumulator result from the while loop from (2) is then @@ -287,11 +285,11 @@ static StatusOr PermuteGatherAndWindowDims( // operand = s32[3,3] parameter(0) // indices = s32[2,2] parameter(1) // ROOT gather = s32[2,3,2] gather(operand, indices), -// output_window_dims={1}, -// elided_window_dims={1}, -// gather_dims_to_operand_dims={1}, +// offset_dims={1}, +// collapsed_slice_dims={1}, +// start_index_map={1}, // index_vector_dim=2, -// window_bounds={3, 1} +// slice_sizes={3, 1} // } // // We'd first reshape indices to s32[4,1], where each row is an index @@ -305,8 +303,8 @@ StatusOr GatherExpander::ExpandGather( HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); - HloInstruction* gather_indices = gather_instr->mutable_operand(1); - const Shape& gather_indices_shape = gather_indices->shape(); + HloInstruction* start_indices = gather_instr->mutable_operand(1); + const Shape& start_indices_shape = start_indices->shape(); const Shape& output_shape = gather_instr->shape(); int64 output_rank = output_shape.dimensions_size(); @@ -314,9 +312,9 @@ StatusOr GatherExpander::ExpandGather( gather_instr->gather_dimension_numbers(); int64 gather_loop_trip_count = 1; - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != dim_numbers.index_vector_dim()) { - gather_loop_trip_count *= gather_indices_shape.dimensions(i); + gather_loop_trip_count *= start_indices_shape.dimensions(i); } } @@ -324,27 +322,27 @@ StatusOr GatherExpander::ExpandGather( return Unimplemented( "Gather operations with more than 2147483647 gather indices are not " "supported. This error occurred for %s.", - gather_instr->ToString().c_str()); + gather_instr->ToString()); } - TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices, - CanonicalizeGatherIndices( - gather_indices, dim_numbers.index_vector_dim())); + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_start_indices, + CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim())); CHECK_EQ(gather_loop_trip_count, - canonical_gather_indices->shape().dimensions(0)); + canonical_start_indices->shape().dimensions(0)); TF_ASSIGN_OR_RETURN( HloInstruction * accumulator_init, CreateGatherLoopAccumulatorInitValue( computation, output_shape.element_type(), - gather_instr->gather_window_bounds(), gather_loop_trip_count, + gather_instr->gather_slice_sizes(), gather_loop_trip_count, gather_instr->gather_dimension_numbers())); StatusOr> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( computation, gather_loop_trip_count, - {operand, canonical_gather_indices, accumulator_init}, + {operand, canonical_start_indices, accumulator_init}, [&](HloInstruction* indvar, const std::vector& loop_state) { return GatherLoopBody(*gather_instr, indvar, loop_state); @@ -356,13 +354,13 @@ StatusOr GatherExpander::ExpandGather( HloInstruction* accumulator_result = gather_loop_result.back(); TF_ASSIGN_OR_RETURN( - HloInstruction * accumulator_with_output_gather_dims_decanonicalized, - AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result, - dim_numbers.index_vector_dim())); + HloInstruction* const accumulator_with_batch_dims_decanonicalized, + AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result, + dim_numbers.index_vector_dim())); - return PermuteGatherAndWindowDims( - accumulator_with_output_gather_dims_decanonicalized, - AsInt64Slice(dim_numbers.output_window_dims()), output_rank); + return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized, + AsInt64Slice(dim_numbers.offset_dims()), + output_rank); } StatusOr GatherExpander::Run(HloModule* module) { @@ -375,8 +373,8 @@ StatusOr GatherExpander::Run(HloModule* module) { std::vector gather_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), - is_nontrivial_gather); + absl::c_copy_if(computation->instructions(), + std::back_inserter(gather_instrs), is_nontrivial_gather); } for (HloInstruction* inst : gather_instrs) { diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index c1fc8574da99fff223c7dbb570b4533f76905b9a..2b39359aae9fc01f1a88a2594108b2772788e826 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -23,9 +23,9 @@ namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic // slices. This lets backends that don't support gather directly to // nevertheless have a minimum level of support. -class GatherExpander : public HloPassInterface { +class GatherExpander : public HloModulePass { public: - tensorflow::StringPiece name() const override { return "gather_expander"; } + absl::string_view name() const override { return "gather_expander"; } StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 020ffcd106862cb2641a9f3bceb70acdd969a458..141dd4d6f10272ce749edc4e91153c365ed322e6 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -28,11 +28,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2147483647,5] parameter(1) ROOT gather = s32[2147483647,3,5] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -55,11 +55,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index e314a469f00abdb9f60ae812c0b78d273dc95dbe..bec02e14f951c6d905b7329be5c02896984279d0 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -43,8 +42,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { } Status GenericTransferManager::WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); @@ -60,17 +58,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( void GenericTransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) { + MutableBorrowingLiteral literal, std::function done) { Status status = stream->BlockHostUntilDone(); if (!status.ok()) { return done(status); } - done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer)); + + done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer, + literal)); } -StatusOr> -GenericTransferManager::TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { +Status GenericTransferManager::TransferLiteralFromDeviceInternal( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal) { VLOG(2) << "transferring literal from device ordinal " << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); @@ -80,9 +80,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), device_buffer.on_host_shape())); - std::unique_ptr literal = - Literal::CreateFromShape(device_buffer.on_host_shape()); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { @@ -91,12 +88,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ - literal->untyped_data(index))); + literal.untyped_data(index))); } return Status::OK(); })); - return std::move(literal); + return Status::OK(); } Status GenericTransferManager::TransferLiteralToDeviceAsync( @@ -128,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( device_memory.size()); // Element is array-shaped: transfer array data to device buffer. const auto subliteral = LiteralSlice(literal, index); - std::unique_ptr relayed_out_literal; + Literal relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { @@ -141,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( // Relayout data before transferring. relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); - source = relayed_out_literal->untyped_data(); + source = relayed_out_literal.untyped_data(); TF_RETURN_IF_ERROR(TransferBufferToDevice( stream, /*size=*/GetByteSizeRequirement(device_subshape), source, @@ -160,12 +157,12 @@ Status GenericTransferManager::TransferLiteralToInfeed( Status GenericTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { return Unimplemented("Generic transfer from Outfeed"); } Status GenericTransferManager::ResetDevices( - tensorflow::gtl::ArraySlice + absl::Span /*executors*/) { return Unimplemented( "Device reset is not yet supported on this platform (b/30481585)"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 3cd002c1bf3555cc2d2891c88b3ad648f8d9fd8c..86c8b1c145a25149a25e7b272babc5c858d476af 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,9 +40,10 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) override; + void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) override; Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, @@ -53,22 +53,21 @@ class GenericTransferManager : public TransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; - Status ResetDevices( - tensorflow::gtl::ArraySlice executors) override; + Status ResetDevices(absl::Span executors) override; int64 GetByteSizeRequirement(const Shape& shape) const override; protected: Status WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) override; private: - StatusOr> TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer); + Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal); // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 4947dd278e9e70c8a1c26b0d7d62f97221c33750..51968d13d492d6cb1d9731c9c18c7c8e4962c0d5 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,7 @@ # Description: # GPU-specific components in XLA service implementation. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") licenses(["notice"]) # Apache 2.0 @@ -55,6 +56,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -65,9 +68,7 @@ cc_library( # srcs = [ # "partition_assignment_test.cc", # ], -# tags = [ -# "requires-gpu-sm35", -# ], +# tags = tf_cuda_tests_tags(), # deps = [ # ":partition_assignment", # "//tensorflow/core:stream_executor_no_cuda", @@ -90,6 +91,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -104,8 +106,12 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -125,6 +131,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -153,7 +161,6 @@ cc_library( ":ir_emission_utils", ":parallel_loop_emitter", ":partition_assignment", - ":while_transformer", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -165,11 +172,14 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", @@ -179,6 +189,12 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], @@ -223,6 +239,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], @@ -242,6 +260,8 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -256,6 +276,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -336,6 +357,11 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -344,6 +370,7 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], deps = [ + ":backend_configs", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -361,15 +388,21 @@ cc_library( hdrs = ["cudnn_convolution_algorithm_picker.h"], deps = [ ":backend_configs", + ":buffer_comparator", ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", ], ) @@ -378,6 +411,8 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":backend_configs", + ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -386,7 +421,10 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -395,6 +433,7 @@ cc_library( srcs = ["cudnn_convolution_rewriter.cc"], hdrs = ["cudnn_convolution_rewriter.h"], deps = [ + ":backend_configs", ":ir_emission_utils", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", @@ -417,7 +456,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], @@ -428,6 +467,7 @@ cc_library( srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], deps = [ + ":gpu_fusible", ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -448,6 +488,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -457,12 +498,14 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":gpu_fusible", ":instruction_fusion", ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -480,6 +523,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -503,6 +547,7 @@ cc_library( srcs = ["fusion_merger.cc"], hdrs = ["fusion_merger.h"], deps = [ + ":gpu_fusible", ":instruction_fusion", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -510,6 +555,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -541,6 +588,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", + "@com_google_absl//absl/memory", ], ) @@ -550,14 +598,11 @@ cc_library( hdrs = ["pad_for_tensor_cores.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:shape_inference", ], ) @@ -597,6 +642,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:core", ], alwayslink = True, # Contains per-platform transfer manager registration @@ -609,13 +655,14 @@ cc_library( deps = [ ":cudnn_convolution_algorithm_picker", ":cudnn_convolution_rewriter", + ":cudnn_fused_convolution_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", ":gpu_executable", + ":gpu_hlo_schedule", ":gpu_hlo_support_checker", ":gpu_layout_assignment", - ":hlo_schedule", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", @@ -652,10 +699,10 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", - "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -666,6 +713,10 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@llvm//:core", ], alwayslink = True, # Contains compiler registration @@ -698,8 +749,8 @@ cc_library( ":xfeed_queue", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -714,6 +765,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -731,6 +783,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -752,54 +805,44 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/strings", ], ) cc_library( - name = "hlo_schedule", - srcs = ["hlo_schedule.cc"], - hdrs = ["hlo_schedule.h"], + name = "gpu_hlo_schedule", + srcs = ["gpu_hlo_schedule.cc"], + hdrs = ["gpu_hlo_schedule.h"], deps = [ ":stream_assignment", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_scheduling", + "@com_google_absl//absl/memory", ], ) tf_cc_test( - name = "hlo_schedule_test", + name = "gpu_hlo_schedule_test", srcs = [ - "hlo_schedule_test.cc", + "gpu_hlo_schedule_test.cc", ], deps = [ - ":hlo_schedule", + ":gpu_hlo_schedule", ":stream_assignment", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - ], -) - -cc_library( - name = "while_transformer", - srcs = ["while_transformer.cc"], - hdrs = ["while_transformer.h"], - deps = [ - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -808,12 +851,12 @@ tf_cc_test( srcs = ["while_transformer_test.cc"], deps = [ ":instruction_fusion", - ":while_transformer", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -850,7 +893,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -863,8 +908,79 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", ], ) + +cc_library( + name = "buffer_comparator", + srcs = ["buffer_comparator.cc"], + hdrs = ["buffer_comparator.h"], + deps = [ + ":gpu_executable", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + ], +) + +xla_test( + name = "buffer_comparator_test", + srcs = ["buffer_comparator_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":buffer_comparator", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "gpu_fusible", + srcs = ["gpu_fusible.cc"], + hdrs = ["gpu_fusible.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:hlo", + ], +) + +tf_cc_test( + name = "gpu_fusible_test", + srcs = ["gpu_fusible_test.cc"], + deps = [ + ":gpu_fusible", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cudnn_fused_convolution_rewriter", + srcs = ["cudnn_fused_convolution_rewriter.cc"], + hdrs = ["cudnn_fused_convolution_rewriter.h"], + deps = [ + ":backend_configs", + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/core:stream_executor_no_cuda", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index 640c6392b8b820c708b853c2a3cea4d4116e85a8..78e14d860e31ace2fcb3f51fb8e0c40a0ea5f3dd 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -24,4 +24,18 @@ message CudnnConvBackendConfig { // true, cudnn may choose not to use tensor cores, e.g. because the GPU or // selected algorithm doesn't support it. bool tensor_ops_enabled = 2; + + // The scaling factor multiplied with the convolution result. + double conv_result_scale = 4; + + // Below are the fields related to cuDNN's fused convolution. Refer to + // CudnnConvParams for their meanings. + + // The requested activation (e.g. relu) after the convolution. It is with type + // stream_executor::dnn::ActivationMode. + int64 activation_mode = 3; + + // The scaling factor multiplied with the side input. If no side input buffer + // is provided, this field must be 0. + double side_input_scale = 5; } diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 537295292b6ced72c4b2c456557b3c06e0aa5254..528209abc75777440163c2e1512658b8ad36315b 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -40,7 +40,7 @@ StatusOr> BufferAllocations::Builder::Build( const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { const int64 num_buffers = buffer_assignment->Allocations().size(); - auto buffer_allocations = WrapUnique(new BufferAllocations( + auto buffer_allocations = absl::WrapUnique(new BufferAllocations( num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { @@ -62,7 +62,7 @@ StatusOr> BufferAllocations::Builder::Build( if (reinterpret_cast(address.opaque()) % expected_alignment != 0) { return InternalError( - "Address of registered buffer %lld must be a multiple of %llx, but " + "Address of registered buffer %d must be a multiple of %x, but " "was %p", i, kEntryParameterAlignBytes, address.opaque()); } @@ -83,7 +83,7 @@ StatusOr> BufferAllocations::Builder::Build( 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " - "multiple of %llx, but was %p", + "multiple of 0x%x, but was %p", kXlaAllocatedBufferAlignBytes, buffer.opaque()); } // We do manual memory management within BufferAllocations. Be sure not diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index f13eab0dd787a2bfa687c991f9d808568360fd24..14186b8faa68ad8492ea4863fcd7bd746e2eae48 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -20,10 +20,10 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc new file mode 100644 index 0000000000000000000000000000000000000000..13c83c9199fb1bbd8b00dbd601afcb677f92bbee --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -0,0 +1,204 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include +#include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { +namespace gpu { + +static constexpr float kTolerance = 0.1f; + +static string GetCompHloText(size_t num_elements) { + // Implements the textual format of the comparison routine, as it's more + // readable. + static constexpr char kF16CompHloText[] = R"( +HloModule CompareF16 + +MaxF32 { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %max = f32[] maximum(%lhs, %rhs) +} + +Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] { + %min_constant = f32[] constant(-65505) + %max_constant = f32[] constant(65505) + %large_constant = f32[] constant(1048576) + %min_values = f32[SIZE] broadcast(%min_constant), dimensions={} + %max_values = f32[SIZE] broadcast(%max_constant), dimensions={} + %large_values = f32[SIZE] broadcast(%large_constant), dimensions={} + + %a = f16[SIZE] parameter(0) + %converted = f32[SIZE] convert(%a) + %clamped = f32[SIZE] clamp(%min_values, %converted, %max_values) + + // Since the clamp() above already took care of infs, only NaNs will cause + // is-finite() to return false. + %is_finite = pred[SIZE] is-finite(%clamped) + ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values) +} + +ENTRY MaxDifference { + %one_constant = f32[] constant(1.0) + %zero_constant = f32[] constant(0.0) + + %ones = f32[SIZE] broadcast(%one_constant), dimensions={} + + %lhs = f16[SIZE] parameter(0) + %rhs = f16[SIZE] parameter(1) + %lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize + %rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize + %sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical) + %sub_abs = f32[SIZE] abs(%sub) + %lhs_abs = f32[SIZE] abs(%lhs_canonical) + %rhs_abs = f32[SIZE] abs(%rhs_canonical) + %max = f32[SIZE] maximum(%lhs_abs, %rhs_abs) + %denominator = f32[SIZE] add(%max, %ones) + %error = f32[SIZE] divide(%sub_abs, %denominator) + ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32 +})"; + return absl::StrReplaceAll(kF16CompHloText, + {{"SIZE", absl::StrCat(num_elements)}}); +} + +StatusOr F16BufferComparator::Create( + se::DeviceMemory ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream) { + auto stream_exec = stream->parent(); + int64 num_elements = ref_buffer.ElementCount(); + + // One may consider using hlo_runner to do all the compilation and execution. + // However, as of the time hlo_runner doesn't support injection for Compiler*, + // Stream*, or even the allocator. We may revisit this in the future if it + // proves to be a maintenance burden. + TF_ASSIGN_OR_RETURN( + auto exec, ([&]() -> StatusOr> { + HloModuleConfig config; + DebugOptions debug_options; + debug_options.set_xla_backend_optimization_level(2); + config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN( + auto module, ParseHloString(GetCompHloText(num_elements), config)); + TF_ASSIGN_OR_RETURN( + module, + compiler->RunHloPasses(std::move(module), stream_exec, nullptr)); + return compiler->RunBackend(std::move(module), stream_exec, nullptr); + }())); + + TF_ASSIGN_OR_RETURN( + auto shaped_buffer, ([&]() -> StatusOr { + auto device_ordinal = stream_exec->device_ordinal(); + TF_ASSIGN_OR_RETURN( + auto owning_buffer, + allocator->Allocate(device_ordinal, ref_buffer.size())); + se::DeviceMemory buffer( + owning_buffer.AsDeviceMemoryBase()); + stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size()); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal); + ret.set_buffer(std::move(owning_buffer), {}); + return std::move(ret); + }())); + + return F16BufferComparator(stream, allocator, std::move(exec), + std::move(shaped_buffer)); +} + +StatusOr F16BufferComparator::CompareEqualImpl( + se::DeviceMemory test_buffer) { + if (ref_buffer_.root_buffer().size() != test_buffer.size()) { + return InternalError("Mismatched buffer size: %d vs %d", + ref_buffer_.root_buffer().size(), test_buffer.size()); + } + + int64 num_elements = test_buffer.ElementCount(); + + TF_ASSIGN_OR_RETURN( + auto result_buffer, ([&]() -> StatusOr { + auto stream_exec = stream_->parent(); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + auto device_ordinal = stream_exec->device_ordinal(); + ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(), + device_ordinal); + shaped_test_buffer.set_buffer(test_buffer, {}); + ExecutableRunOptions run_options; + run_options.set_device_ordinal(stream_exec->device_ordinal()); + run_options.set_stream(stream_); + run_options.set_allocator(allocator_); + ServiceExecutableRunOptions service_run_options(run_options); + return exec_->ExecuteOnStream( + &service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr); + }())); + + float result; + CHECK(result_buffer.root_buffer().size() == sizeof(result)); + stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result)); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + return result < kTolerance; +} + +StatusOr F16BufferComparator::CompareEqual( + se::DeviceMemory test_buffer) { + TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer)); + if (result) { + return true; + } + // Host side code that does the same thing, but report some of the + // differences as well. + int64 n = test_buffer.ElementCount(); + std::vector host_ref_buffer(n), host_test_buffer(n); + stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(), + ref_buffer_.root_buffer().size()); + stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size()); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + + const auto canonicalize = [](float a) -> float { + constexpr float kBigNumer = 1048576.; + constexpr float kMaxFp16Value = 65504.; + if (std::isnan(a)) { + return kBigNumer; + } + if (std::isinf(a)) { + if (a < 0) { + return -(kMaxFp16Value + 1); + } + return kMaxFp16Value + 1; + } + return a; + }; + int differences_seen = 0; + for (int64 i = 0; i < n && differences_seen < 10; i++) { + float original_ref = static_cast(host_ref_buffer[i]); + float original_test = static_cast(host_test_buffer[i]); + float ref = canonicalize(original_ref); + float test = canonicalize(original_test); + if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) < + kTolerance)) { + differences_seen++; + LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs " + << original_test; + } + } + + return false; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h new file mode 100644 index 0000000000000000000000000000000000000000..bf2ba78ceacaea1070830f758c3712b1378bd96f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A fp16 comparator that internally keeps a reference buffer, and compares it +// against other test buffers. +class F16BufferComparator { + public: + F16BufferComparator(const F16BufferComparator&) = delete; + F16BufferComparator(F16BufferComparator&&) = default; + + // Creates a new comparator. It internally allocates a buffer initialized by + // ref_buffer. + static StatusOr Create( + se::DeviceMemory ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream); + + // Returns true if the internally allocated buffer "compares equal" to + // test_buffer. The definition of "equal" is: + // * All NaNs equal. + // * All infs are treated as 65505 or -65505, so that this checker is tolerant + // to fp16 overflows. + // * With NaNs and infs taken care of, a and b compare equal iff: + // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance + // + // See the implementation for the tolerance value. + StatusOr CompareEqual(se::DeviceMemory test_buffer); + + private: + F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator, + std::unique_ptr exec, + ScopedShapedBuffer ref_buffer) + : stream_(stream), + allocator_(allocator), + exec_(std::move(exec)), + ref_buffer_(std::move(ref_buffer)) {} + + StatusOr CompareEqualImpl(se::DeviceMemory test_buffer); + + se::Stream* stream_; + DeviceMemoryAllocator* allocator_; + std::unique_ptr exec_; + ScopedShapedBuffer ref_buffer_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..33761d1bd8807df225e2cf505303b120e418576f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class BufferComparatorTest : public testing::Test { + protected: + BufferComparatorTest() + : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()), + stream_exec_(backend_->default_stream_executor()), + allocator_(stream_exec_->platform(), {stream_exec_}), + compiler_(Compiler::GetForPlatform(stream_exec_->platform()) + .ConsumeValueOrDie()) {} + + // Take floats only for convenience. Still uses half internally. + bool CompareEqualFloatBuffers(const std::vector& lhs_float, + const std::vector& rhs_float) { + std::vector lhs(lhs_float.begin(), lhs_float.end()); + std::vector rhs(rhs_float.begin(), rhs_float.end()); + se::Stream stream(stream_exec_); + stream.Init(); + + auto owning_lhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), lhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto owning_rhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), rhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto lhs_buffer = + se::DeviceMemory(owning_lhs_buffer.AsDeviceMemoryBase()); + auto rhs_buffer = + se::DeviceMemory(owning_rhs_buffer.AsDeviceMemoryBase()); + + stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size()); + stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size()); + + TF_CHECK_OK(stream.BlockHostUntilDone()); + + return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_, + &stream) + .ConsumeValueOrDie() + .CompareEqual(rhs_buffer) + .ConsumeValueOrDie(); + } + + std::unique_ptr backend_; + se::StreamExecutor* stream_exec_; + StreamExecutorMemoryAllocator allocator_; + Compiler* compiler_; +}; + +TEST_F(BufferComparatorTest, TestNaNs) { + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")})); + // NaN values with different bit patterns should compare equal. + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")})); + EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.})); +} + +TEST_F(BufferComparatorTest, TestInfs) { + const auto inf = std::numeric_limits::infinity(); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); +} + +TEST_F(BufferComparatorTest, TestNumbers) { + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); +} + +TEST_F(BufferComparatorTest, TestMultiple) { + EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60}, + {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 5780e0af40699bb6ac2c190c09cd02023fb44db7..9ed523998bf07567133fdac0e40b12b8ce4ea3b0 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -59,7 +59,7 @@ Status ConditionalThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to retrieve predicate value on stream %p: %s.", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } // Execute the true or the false computation depending on the value of the diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 7833a4077e6c6ee4960665f37fb01a35530fd302..4effea637d01bf23b54d341b77306b20b1b133c8 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,78 +17,50 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { -using se::dnn::AlgorithmDesc; - ConvolutionThunk::ConvolutionThunk( - CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, - const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo) - : Thunk(Kind::kConvolution, hlo), - convolution_kind_(convolution_kind), - input_buffer_(input_buffer), - filter_buffer_(filter_buffer), - output_buffer_(output_buffer), - tuple_result_buffer_(tuple_result_buffer), - scratch_buffer_(scratch_buffer), - input_shape_(input_shape), - filter_shape_(filter_shape), - output_shape_(output_shape), - window_(window), - dim_nums_(dim_nums), - algorithm_(algorithm), - tensor_ops_enabled_(tensor_ops_enabled) {} + const HloCustomCallInstruction* cudnn_call, + std::vector operand_slices, + BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + operand_buffers_(std::move(operand_slices)), + result_buffer_(result_slice), + scratch_buffer_(scratch_slice), + tuple_result_buffer_(tuple_result_slice) {} Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - se::DeviceMemoryBase input_data = - buffer_allocations.GetDeviceAddress(input_buffer_); - se::DeviceMemoryBase filter_data = - buffer_allocations.GetDeviceAddress(filter_buffer_); - se::DeviceMemoryBase output_data = - buffer_allocations.GetDeviceAddress(output_buffer_); + std::vector operand_se_buffers; + for (const auto& buffer : operand_buffers_) { + operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); + } + + se::DeviceMemoryBase result_buffer = + buffer_allocations.GetDeviceAddress(result_buffer_); + se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - se::dnn::AlgorithmConfig algorithm_config( - se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); - auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution( - convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, - stream)); + TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_, + absl::MakeSpan(operand_se_buffers), + result_buffer, scratch, stream)); - // Figure out which of output/input/filter is the result produced by - // this op, and write the result tuple. - void* result_ptr = [&] { - switch (convolution_kind_) { - case CudnnConvKind::kForward: - return output_data.opaque(); - case CudnnConvKind::kBackwardInput: - return input_data.opaque(); - case CudnnConvKind::kBackwardFilter: - return filter_data.opaque(); - } - }(); - void* ptrs[] = {result_ptr, scratch.opaque()}; + void* ptrs[] = {result_buffer.opaque(), scratch.opaque()}; se::DeviceMemory tuple_addr( buffer_allocations.GetDeviceAddress(tuple_result_buffer_)); stream->ThenMemcpyH2D(ptrs, &tuple_addr); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index d76ca6698dcf462c3c4961ce6a9784822af3a81f..f53bc541983378819dba36489dd69c348f50af32 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" @@ -23,16 +24,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { -// This class stores everything that StreamExecutor needs to launch a BNN +// This class stores everything that StreamExecutor needs to launch a DNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. @@ -41,26 +42,12 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that - // we should use the default (i.e. baseline) cudnn algorithm. - // - // Note that "output" here doesn't refer to the output from running this - // thunk, but rather to the "output" of a hypothetical forward convolution - // that corresponds to this input+filter+output triple. That is, the result - // generated by this thunk is "output" for forward convs, "input" for - // backward-input convs, and "filter" for backward-filter convs. - // - // Semantics of null hlo_instruction argument are as in Thunk. - ConvolutionThunk(CudnnConvKind convolution_kind, - const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, - const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo); + // operand_slices should be in the same order as cudnn_call->operands(). + ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, + std::vector operand_slices, + BufferAllocation::Slice result_slice, + BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -71,35 +58,11 @@ class ConvolutionThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - class ScratchAllocator; - - Status Convolve(const se::dnn::BatchDescriptor& input_descriptor, - se::DeviceMemory input_data, - const se::dnn::FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const se::dnn::BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const se::dnn::ConvolutionDescriptor& convolution_descriptor, - const se::dnn::AlgorithmConfig& algorithm_config, - se::Stream* stream, ScratchAllocator* scratch_allocator, - se::dnn::ProfileResult* profile_result); - - const CudnnConvKind convolution_kind_; - - const BufferAllocation::Slice input_buffer_; - const BufferAllocation::Slice filter_buffer_; - const BufferAllocation::Slice output_buffer_; - const BufferAllocation::Slice tuple_result_buffer_; - const BufferAllocation::Slice scratch_buffer_; - - const Shape input_shape_; - const Shape filter_shape_; - const Shape output_shape_; - - const Window window_; - const ConvolutionDimensionNumbers dim_nums_; - int64 algorithm_; - bool tensor_ops_enabled_; + const HloCustomCallInstruction* cudnn_call_; + std::vector operand_buffers_; + BufferAllocation::Slice result_buffer_; + BufferAllocation::Slice scratch_buffer_; + BufferAllocation::Slice tuple_result_buffer_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h index e09cde9abf85454c7a020566cd8c2671ae12ffc3..c3f58508ddd4451312325b0d440473515812dac9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -52,11 +52,9 @@ namespace gpu { // The GPU backend does not implement a lowering for the batchnorm HLOs -- it // expects them to be lowered to cudnn calls via this pass or to HLO soup via // BatchNormRewriter. -class CudnnBatchNormRewriter : public HloPassInterface { +class CudnnBatchNormRewriter : public HloModulePass { public: - tensorflow::StringPiece name() const override { - return "cudnn_batchnorm_rewriter"; - } + absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 7b172812c36bb141787ef3a9285d6f7ce13e343b..bc3c6f72f6799f84169748465d62c3f2a306d5fc 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -17,12 +17,11 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 7348307ec8a7286dfb733d6b9685862b20f11ac9..7125673887d28729287d67577bcfa06423f85611 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,24 +14,26 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" namespace xla { namespace gpu { namespace { +using absl::optional; using se::DeviceMemoryBase; using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmDesc; -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; class ScratchAllocator : public se::ScratchAllocator { public: @@ -59,8 +61,8 @@ StatusOr> ScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -74,68 +76,38 @@ StatusOr> ScratchAllocator::AllocateBytes( return se::DeviceMemory(buffer_addr); } -// Determines whether we can safely perform a winograd non-fused convolution for -// the given input and output shapes. This works around b/68264959, an integer -// overflow in cuDNNv5 and cuDNNv6. -bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, - const Shape& output_shape, - const ConvolutionDimensionNumbers& dnums, - se::StreamExecutor* stream_exec) { - // Skip this check for cudnn7 and newer. - auto version = stream_exec->AsDnn()->GetVersion(); - if (version.ok() && version.ValueOrDie().major_version() >= 7) { - return true; - } - - int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); - int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); - int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); - int64 in_cols = - dnums.input_spatial_dimensions_size() == 1 - ? 1 - : input_shape.dimensions(dnums.input_spatial_dimensions(1)); - int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension()); - - int64 total_size = CeilOfRatio(batch, int64{16}) * - std::max(in_depths, out_depths) * in_cols * in_rows * - sizeof(float); - - const int64 threshold = 1L << 31; - return total_size < threshold; -} - std::vector GetAlgorithms(CudnnConvKind kind, - bool with_winograd_nonfused, se::StreamExecutor* stream_exec) { std::vector algorithms; + bool succ = false; switch (kind) { case CudnnConvKind::kBackwardFilter: - CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - with_winograd_nonfused, &algorithms)); + succ = + stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms); break; case CudnnConvKind::kBackwardInput: - CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - with_winograd_nonfused, &algorithms)); + succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms); break; case CudnnConvKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); + case CudnnConvKind::kForwardActivation: + succ = stream_exec->GetConvolveAlgorithms(true, &algorithms); break; } + DCHECK(succ); return algorithms; } string AlgorithmToString(const AlgorithmDesc& algo) { if (algo.tensor_ops_enabled()) { - return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + return absl::StrCat(algo.algo_id(), "+TC"); } - return tensorflow::strings::StrCat(algo.algo_id()); + return absl::StrCat(algo.algo_id()); } string NumBytesToString(int64 bytes) { - return tensorflow::strings::StrCat( - tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); + return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (", + bytes, "B)"); } // Acquires a process-global lock on the device pointed to by the given @@ -173,11 +145,14 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -optional> +StatusOr> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + HloCustomCallInstruction* instr) { + // TODO(timshen): for now only check fp16. It can be expanded to other types, + // with some work on the HLO routines. + const bool cross_check_enabled = + instr->shape().tuple_shapes(0).element_type() == xla::F16; + // Don't run this function concurrently on the same GPU. // // This is a bit of a hack and doesn't protect us against arbitrary concurrent @@ -185,6 +160,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // concurrently and then run them sequentially. tensorflow::mutex_lock lock = LockGpu(stream_exec_); + // Make sure any previous activity on this executor is done. We don't want to + // interfere with programs that are still running on the GPU. + if (!stream_exec_->SynchronizeAllActivity()) { + return InternalError("Failed to synchronize GPU for autotuning."); + } + // Create a stream for us to do our work on. se::Stream stream{stream_exec_}; stream.Init(); @@ -197,75 +178,123 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( if (allocator_ != nullptr) { allocator = allocator_; } else { - se_allocator.emplace( - stream_exec_->platform(), - tensorflow::gtl::ArraySlice({stream_exec_})); + se_allocator.emplace(stream_exec_->platform(), + absl::Span({stream_exec_})); allocator = &*se_allocator; } + const auto initialize_buffer = [&stream, cross_check_enabled]( + DeviceMemoryBase buffer) { + if (cross_check_enabled) { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because zero-inputs + // may not always reveal the bugs. + CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); + size_t left_over_bytes = buffer.size() % 4; + CHECK_EQ(0, left_over_bytes % 2); + + constexpr float kBroadcastedConstant = 0.1f; + static const Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant), + Eigen::half(kBroadcastedConstant)}; + uint32 bits; + static_assert(sizeof(bits) == sizeof(halfs), ""); + memcpy(&bits, halfs, sizeof(bits)); + + size_t aligned_size = buffer.size() / 4 * 4; + stream.ThenMemset32(&buffer, bits, aligned_size); + + DeviceMemoryBase left_over( + static_cast(buffer.opaque()) + aligned_size, left_over_bytes); + stream.ThenMemcpy(&left_over, halfs, left_over_bytes); + } else { + // Although we don't have evidence this matters, zero out the buffers + // before autotuning. It's conceivable that using uninitialized memory as + // the inputs might affect performance if e.g. the inputs contain + // denormals, and this is easy enough. + stream.ThenMemZero(&buffer, buffer.size()); + } + }; + // Allocate space for the input, filter, and output of the convolution. We // use a ScratchAllocator for this instead of calling allocator_ directly so // that our allocations don't leak. - // - // We don't put any data in these buffers, because (in theory, anyway) the - // speed of a conv isn't affected by the data being convolved. ScratchAllocator input_output_allocator(device_ordinal, allocator); - StatusOr maybe_input_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(input_shape)); - StatusOr maybe_filter_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(filter_shape)); - StatusOr maybe_output_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(output_shape)); - if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || - !maybe_output_buf.ok()) { - LOG(WARNING) - << "Couldn't allocate space for input/filter/output of convolution " - << instr->ToString() << ". Falling back to default algorithm."; - return nullopt; + std::vector operand_buffers; + for (const auto* operand : instr->operands()) { + TF_ASSIGN_OR_RETURN(auto buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(operand->shape()))); + initialize_buffer(buffer); + operand_buffers.push_back(buffer); } + TF_ASSIGN_OR_RETURN( + auto result_buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); + initialize_buffer(result_buffer); - DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); - DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); - DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); - - // Although we don't have evidence this matters, zero out the buffers before - // autotuning. It's conceivable that using uninitialized memory as the inputs - // might affect performance if e.g. the inputs contain denormals, and this is - // easy enough. - if (!stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()) - .BlockHostUntilDone() - .ok()) { - LOG(WARNING) - << "Couldn't zero out input/filter/output buffer for convolution " - << instr->ToString() << ". Falling back to default algorithm."; - return nullopt; - } - - const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( - input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; - - for (const AlgorithmDesc& alg : - GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + TF_ASSIGN_OR_RETURN(auto backend_config, + instr->backend_config()); + + optional comparator; + // Use the first algorithm that's supported as reference. There isn't a + // particular reason to use it, as any algorithm sufficies. It doesn't make + // this algorithm considered correct, though. + optional first_algorithm; + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); + for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = - RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, - AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + backend_config.set_algorithm(alg.algo_id()); + backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); + TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); + bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers), + result_buffer, &scratch_allocator, + &stream, &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { + const bool crash_on_checking_failure = + instr->GetModule() + ->config() + .debug_options() + .xla_gpu_crash_on_verification_failures(); + if (comparator.has_value()) { + StatusOr result = comparator->CompareEqual( + se::DeviceMemory(result_buffer)); + if (!result.ok()) { + LOG(ERROR) << "Unable to compare " + << AlgorithmToString(*first_algorithm) << " against " + << AlgorithmToString(alg) << " for " << instr->ToString() + << ": " << result.status(); + CHECK(!crash_on_checking_failure); + } else if (!result.ValueOrDie()) { + LOG(ERROR) << "Results mismatch between different convolution " + "algorithms. This is likely a bug in convolution, or " + "an excessive loss of precision in convolution. " + << instr->ToString() << " for " + << AlgorithmToString(*first_algorithm) << " vs " + << AlgorithmToString(alg); + CHECK(!crash_on_checking_failure); + } + } else if (cross_check_enabled) { + auto comp = F16BufferComparator::Create( + se::DeviceMemory(result_buffer), compiler_, allocator, + &stream); + if (comp.ok()) { + comparator.emplace(comp.ConsumeValueOrDie()); + first_algorithm.emplace(alg); + } else { + LOG(ERROR) << "Fail to initialize buffer comparator: " + << comp.status() << ", instruction: " << instr->ToString(); + CHECK(!crash_on_checking_failure); + } + } int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " succeeded, taking " << profile_result.elapsed_time_in_ms() @@ -292,41 +321,21 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( best_result_bytes_used); } - LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString() - << " failed. Falling back to default algorithm."; - return nullopt; + return InternalError( + "All algorithms tried for convolution %s failed. Falling back to " + "default algorithm.", + instr->ToString()); } StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); - const auto& call_target = instr->custom_call_target(); - const auto& lhs_shape = instr->operand(0)->shape(); - const auto& rhs_shape = instr->operand(1)->shape(); - const auto& conv_result_shape = instr->shape().tuple_shapes(0); - optional> alg_scratch_and_tc; - if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, - instr->window(), instr->convolution_dimension_numbers(), instr); - } else if (call_target == kCudnnConvBackwardInputCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), - instr->convolution_dimension_numbers(), instr); - } else if (call_target == kCudnnConvBackwardFilterCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, - instr->window(), instr->convolution_dimension_numbers(), instr); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instr->ToString(); - } + StatusOr> alg_scratch_and_tc = + PickBestAlgorithm(Cast(instr)); - if (!alg_scratch_and_tc.has_value()) { + if (!alg_scratch_and_tc.ok()) { + LOG(ERROR) << alg_scratch_and_tc.status(); return false; } @@ -334,7 +343,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( bool tensor_ops_enabled; int64 scratch_bytes; - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc; + std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = + alg_scratch_and_tc.ConsumeValueOrDie(); VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " << NumBytesToString(scratch_bytes) @@ -348,18 +358,14 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), ShapeUtil::MakeShape(U8, {scratch_bytes})}); - CudnnConvBackendConfig backend_config; + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + instr->backend_config()); backend_config.set_algorithm(algorithm); backend_config.set_tensor_ops_enabled(tensor_ops_enabled); - HloInstruction* new_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1)}, - instr->custom_call_target())); - new_call->set_window(instr->window()); - new_call->set_convolution_dimension_numbers( - instr->convolution_dimension_numbers()); + HloInstruction* new_call = computation->AddInstruction( + instr->CloneWithNewOperands(new_call_shape, instr->operands())); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index bc5d1ce94afd2075a006899f0f6bcf64352e5e99..aeda2fc7f8b4d6169fc2baa8975119ba7bf68dd2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -28,16 +30,17 @@ namespace gpu { // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for // each and adding explicit scratch space to the CustomCalls. -class CudnnConvolutionAlgorithmPicker : public HloPassInterface { +class CudnnConvolutionAlgorithmPicker : public HloModulePass { public: // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator) - : stream_exec_(stream_exec), allocator_(allocator) {} + DeviceMemoryAllocator* allocator, + Compiler* compiler) + : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-algorithm-picker"; } @@ -46,13 +49,12 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { private: StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - tensorflow::gtl::optional> PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); + StatusOr> PickBestAlgorithm( + HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null + Compiler* compiler_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 905b5ee8767d0fa0514c7f1abf83bc089cd08045..ef292373018295f5100b91c343df274b626c2fa1 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -15,11 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include #include #include #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -34,6 +36,32 @@ namespace gpu { namespace { +HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape, + HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { + HloComputation* computation = lhs->parent(); + + // This call returns a tuple of (conv_result, scratch_memory), where + // conv_result is the actual result of the convolution, and scratch_memory is + // temporary memory used by cudnn. + // + // At the moment, we don't know how much scratch memory this conv is going to + // use, so we put u8[0] in this place. Later on another pass will choose + // which conv algorithm to use, and at that point we'll modify the shape of + // this second tuple element. + Shape call_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); + + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); + custom_call->set_window(window); + custom_call->set_convolution_dimension_numbers(dnums); + custom_call->set_feature_group_count(feature_group_count); + return custom_call; +} + bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { const ConvolutionDimensionNumbers& dnums = conv->convolution_dimension_numbers(); @@ -59,6 +87,9 @@ std::tuple MatchBackwardFilter( HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + if (conv->feature_group_count() > 1) { + return no_match_result; + } // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -213,42 +244,55 @@ std::tuple MatchBackwardFilter( // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple MatchBackwardInput( - HloInstruction* conv) { +std::tuple +MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); + + // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also + // for the backward input convolution, but at least for now with version 7.1.4 + // it is slower. This needs to be re-evaluated for future cuDNN versions. + // Note that we already have the necessary code down below, the only thing to + // enable it is to remove the following early return. + if (conv->feature_group_count() > 1) { + return no_match_result; + } // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); HloInstruction* reverse_filter = conv->mutable_operand(1); - - // Match the reverse of the filter. ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); - const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions(); - if (reverse_filter->opcode() == HloOpcode::kReverse) { - if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() || - !std::is_permutation(kernel_spatial_dims.begin(), - kernel_spatial_dims.end(), - reverse_filter->dimensions().begin())) { - VLOG(1) - << "Backward input convolution should reverse all kernel dimensions."; - return no_match_result; - } - } else { - // Possibly 1x1 filter. - for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) { - if (conv->window().dimensions(i).size() != 1) { - VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: " - << reverse_filter->ToString(); - return no_match_result; - } - } - if (!window_util::HasBaseDilation(conv->window())) { - VLOG(1) << conv->ToString() - << " is a regular forward convolution. No need " - "to fold it to a backward input convolution."; - return no_match_result; - } + + // We pattern-match to a backwards input conv if: + // + // - all spatial dims of the filter are reversed + // + // OR + // + // - filter is 1x1 or a constant AND + // - conv has base dilation (otherwise this is just a regular forward conv). + // + // The final criterion above is just for canonicalization; cudnn seems to run + // just as fast if we canonicalize 1x1/constant filters without base dilation + // to forward or backward convs. We canonicalize to forward conv because (a) + // it's more natural (constant filters usually show up when doing inference, + // and having backwards convolutions in inference graphs would be weird), and + // (b) cudnn has special fusions for forward conv plus bias and activation, + // and we want to pattern-match to that after running this pass. + bool is_reversed_filter = + reverse_filter->opcode() == HloOpcode::kReverse && + absl::c_is_permutation(dnums.kernel_spatial_dimensions(), + reverse_filter->dimensions()); + bool is_1x1_filter = + absl::c_all_of(conv->window().dimensions(), + [](const WindowDimension& d) { return d.size() == 1; }); + if (!is_reversed_filter && + !(window_util::HasBaseDilation(conv->window()) && + (reverse_filter->IsConstant() || is_1x1_filter))) { + VLOG(1) << "Can't match to backwards convolution. Either filter is not " + "kReverse, or it's not a base-dilated conv with a 1x1 or " + "constant filter."; + return no_match_result; } // Match padding and dilation of the forward convolution. @@ -373,23 +417,70 @@ std::tuple MatchBackwardInput( } } - // Fuse the matched HLOs into a backward convolution instruction. - // - // If the reverse is omitted (for 1x1 filters) in the original pattern, we add - // it back in the fusion instruction so that later passes (such as - // PadInsertion) can handle such fusion instructions easily. - if (reverse_filter->opcode() != HloOpcode::kReverse) { - reverse_filter = reverse_filter->parent()->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); - TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); - } + // OK, it's a match! Switch the input feature dimension with the output + // feature dimension. This is the way cuDNN expects it to be. dnums.set_kernel_input_feature_dimension( conv->convolution_dimension_numbers().kernel_output_feature_dimension()); dnums.set_kernel_output_feature_dimension( conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, new_window, dnums); + // If we matched against a constant, we need to add a reverse op that can be + // subsumed by the cuDNN call. algebraic-simplifier will later remove any + // unnecessary reverses. + if (reverse_filter->opcode() != HloOpcode::kReverse && + reverse_filter->IsConstant()) { + // Create a double-reverse, which is a nop. + HloComputation* c = conv->parent(); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); + } + + // Calculate the 'rhs' that goes into the backward input convolution. + HloInstruction* rhs = reverse_filter; + // One reverse is subsumed by the cuDNN call. + if (rhs->opcode() == HloOpcode::kReverse) { + rhs = rhs->mutable_operand(0); + } + if (conv->feature_group_count() == 1) { + return std::make_tuple(true, new_window, dnums, rhs); + } + + // Handle grouped convolutions. Because we swapped the input feature dimension + // with the output feature dimension, we need to also reshape the kernel so + // that the 'feature_group_count' parameter still makes sense. The + // 'feature_group_count' parameter essentially specifies how often the + // 'kernel_input_feature_dimension' is repeated. So when we swap these + // dimensions, we need to divide the new 'kernel_input_feature_dimension' by + // 'feature_group_count' and multiply the new + // 'kernel_output_feature_dimension' by 'feature_group_count'. + Shape new_shape = rhs->shape(); + int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); + int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); + + // In the backward convolution case, the spatial dimensions become the + // feature dimensions, and we are guaranteed that the spatial dimensions are + // adjacent. + CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); + int64 input_features = new_shape.dimensions(input_feature_dimension); + int64 output_features = new_shape.dimensions(output_feature_dimension); + new_shape.set_dimensions(input_feature_dimension, + input_features / conv->feature_group_count()); + new_shape.set_dimensions(output_feature_dimension, + output_features * conv->feature_group_count()); + HloComputation* c = conv->parent(); + rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); + return std::make_tuple(true, new_window, dnums, rhs); +} + +CudnnConvBackendConfig GetDefaultBackendConfig() { + CudnnConvBackendConfig config; + config.set_conv_result_scale(1); + return config; } // Tries to rewrite a single convolution into a call to cudnn. @@ -400,30 +491,28 @@ StatusOr RunOnInstruction(HloInstruction* conv) { bool match; Window window; ConvolutionDimensionNumbers dnums; + HloInstruction* rhs; std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { - return CreateCudnnConvBackwardFilter( - conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), - window, dnums); + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), + conv->mutable_operand(0), conv->mutable_operand(1), + window, dnums, conv->feature_group_count()); } - std::tie(match, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { - // Backward input conv subsumes the conv plus the reverse in operand 1. - HloInstruction* reverse = conv->mutable_operand(1); - CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); - HloInstruction* rhs = reverse->mutable_operand(0); - - return CreateCudnnConvBackwardInput( - conv->shape(), conv->mutable_operand(0), rhs, window, dnums); + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(), + conv->mutable_operand(0), rhs, window, dnums, + conv->feature_group_count()); } // If all else fails, try a forward convolution. if (CanImplementAsCudnnForwardConv(conv)) { - return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), - conv->mutable_operand(1), conv->window(), - conv->convolution_dimension_numbers()); + return CreateCudnnConv( + kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0), + conv->mutable_operand(1), conv->window(), + conv->convolution_dimension_numbers(), conv->feature_group_count()); } return nullptr; @@ -433,6 +522,9 @@ StatusOr RunOnInstruction(HloInstruction* conv) { return false; } + TF_RETURN_IF_ERROR( + custom_call->set_backend_config(GetDefaultBackendConfig())); + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out // the conv result and replace `conv` with it. TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index 0c0578d88840fed1d77f7456c9acef27dec380f5..8d7c6fdab510407428a115579a90e8cf85e9fad2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -24,9 +24,9 @@ namespace gpu { // Rewrites plain convolutions, backwards-filter convolutions, and // backwards-input convolutions into CustomCall HLOs that call into cuDNN. -class CudnnConvolutionRewriter : public HloPassInterface { +class CudnnConvolutionRewriter : public HloModulePass { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-rewriter"; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 65588b6aaf24da628ea586eb52c462b78b8daaa7..d237f8930b74d460ad3d4602670a5afb19b496a2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -32,10 +32,13 @@ namespace gpu { namespace { namespace op = xla::testing::opcode_matchers; +using ::testing::_; -class CudnnConvolutionRewriterTest : public HloTestBase { +class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { public: - CudnnConvolutionRewriterTest() { + CudnnConvolutionRewriterTest() + : HloVerifiedTestBase(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false) { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -104,17 +107,17 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { conv_window.mutable_dimensions(1)->set_size(2); conv_window.mutable_dimensions(1)->set_window_dilation(2); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -132,17 +135,17 @@ TEST_F(CudnnConvolutionRewriterTest, Window conv_window = default_conv_window_; conv_window.mutable_dimensions(1)->set_size(3); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -167,12 +170,13 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -197,12 +201,13 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -225,12 +230,13 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -269,18 +275,19 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, - /*rhs=*/reverse_kernel, conv_window, conv_dnums)); + /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, + conv_dnums, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, conv_dnums) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, conv_window, conv_dnums) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( @@ -316,16 +323,16 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - conv_window, + /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, conv_window, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -347,17 +354,18 @@ TEST_F(CudnnConvolutionRewriterTest, 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - default_conv_window_, - tf_default_dnums_for_backward_input_) + ShapeInference::InferConvolveShape( + output->shape(), kernel->shape(), /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, default_conv_window_, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -399,18 +407,20 @@ TEST_F(CudnnConvolutionRewriterTest, } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -446,18 +456,20 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -499,18 +511,20 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_base_dilation(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -551,23 +565,51 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_padding_high(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } +// Check that we will materialize a reversed version of a constant in order to +// pattern-match a backwards input convolution. +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) { + Array4D constant_arr(4, 4, 2, 2); + constant_arr.FillIota(0); + string constant_str = + LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); + ParseAndVerifyModule(absl::StrFormat(R"( + HloModule test + + ENTRY entry_computation { + param0 = f32[128,2,16,16]{3,2,1,0} parameter(0) + constant = f32[4,4,2,2]{3,2,1,0} constant(%s) + ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant), + window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2}, + dim_labels=bf01_01oi->bf01, feature_group_count=1 + })", + constant_str)); + EXPECT_TRUE(RunPass(&module())); + EXPECT_THAT( + module().entry_computation()->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _, + op::Reverse(op::Constant())), + 0)); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 0645fbb3ad39f1f1649caf45a6068b5a196c30b9..89dd1bb272663ac1f6eecbaae070d201d38e44c8 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,7 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -36,6 +39,42 @@ using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; using se::dnn::ProfileResult; +struct CudnnConvParams { + // Here are the fields related to cuDNN's fused convolution. The result thus + // is defined as: + // activation(conv_result_scale * conv(x, w) + + // side_input_scale * side_input + broadcast(bias)) + // + // The most common fused conv is conv forward + relu/identity, for example. + // + // bias_buf is a single-dimensional array, with the length equal to the number + // of output features. It'll be broadcasted to the output shape in order to be + // added to the final results. + // + // side_input_buf, if valid, must have the same shape as the output buffer. + struct FusionParams { + se::dnn::ActivationMode mode; + double side_input_scale; + se::DeviceMemoryBase bias_buf; + se::DeviceMemoryBase side_input_buf; // nullable + }; + + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; + double conv_result_scale; + + absl::optional fusion; +}; + // A StreamExecutor ScratchAllocator that wraps a single XLA allocation, // returning it (in its entirety) the first time Allocate() is called. class ScratchBufAllocator : public se::ScratchAllocator { @@ -56,7 +95,7 @@ class ScratchBufAllocator : public se::ScratchAllocator { "Can't allocate twice from a ScratchBufAllocator."); } if (byte_size > scratch_.size()) { - return se::port::InternalError(tensorflow::strings::StrCat( + return se::port::InternalError(absl::StrCat( "Can't allocate ", byte_size, " bytes from a ScratchBufAllocator of size ", scratch_.size())); } @@ -71,20 +110,29 @@ class ScratchBufAllocator : public se::ScratchAllocator { }; template -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, DeviceMemory input_buf, - DeviceMemory filter_buf, DeviceMemory output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, - Stream* stream, ProfileResult* profile_result /*= nullptr*/) { +Status RunCudnnConvolutionImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + CudnnConvKind kind = params.kind; + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + DeviceMemory input_buf(params.input_buf); + DeviceMemory filter_buf(params.filter_buf); + DeviceMemory output_buf(params.output_buf); + const Window& window = *params.window; + const ConvolutionDimensionNumbers& dnums = *params.dnums; + int64 feature_group_count = params.feature_group_count; + AlgorithmConfig algorithm = params.algorithm; + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); - VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }"; - VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }"; - VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }"; + VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); + VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape); + VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape); VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; @@ -96,15 +144,9 @@ Status RunCudnnConvolution( // tensorflow/python/ops/nn_ops.py). const int effective_num_dimensions = std::max(2, num_dimensions); - if (std::is_same::value) { - CHECK_EQ(F32, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else if (std::is_same::value) { - CHECK_EQ(F16, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else { - LOG(FATAL) << ShapeUtil::HumanString(output_shape); - } + CHECK_EQ(primitive_util::NativeToPrimitiveType(), + output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); @@ -149,6 +191,7 @@ Status RunCudnnConvolution( } ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + convolution_descriptor.set_group_count(feature_group_count); for (int dim = 0; dim < num_dimensions; ++dim) { convolution_descriptor .set_zero_padding( @@ -181,86 +224,195 @@ Status RunCudnnConvolution( switch (kind) { case CudnnConvKind::kForward: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveWithAlgorithm( input_descriptor, input_buf, filter_descriptor, filter_buf, convolution_descriptor, output_descriptor, &output_buf, scratch_allocator, algorithm, profile_result); break; case CudnnConvKind::kBackwardInput: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveBackwardDataWithAlgorithm( filter_descriptor, filter_buf, output_descriptor, output_buf, convolution_descriptor, input_descriptor, &input_buf, scratch_allocator, algorithm, profile_result); break; case CudnnConvKind::kBackwardFilter: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveBackwardFilterWithAlgorithm( input_descriptor, input_buf, output_descriptor, output_buf, convolution_descriptor, filter_descriptor, &filter_buf, scratch_allocator, algorithm, profile_result); break; + case CudnnConvKind::kForwardActivation: { + BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_layout(output_dl); + + se::DeviceMemory side_input(params.fusion->side_input_buf); + // If there is no side input, use output as the side input. + if (side_input.is_null()) { + if (params.fusion->side_input_scale != 0) { + return InternalError( + "Side input scale is not 0, yet no side input buffer is " + "provided"); + } + // Since side-input scale is 0, the values in the side input don't + // matter. The simplest thing to do would be to pass in a null buffer + // for the side input, but cudnn doesn't allow this. cudnn does promise + // that if side-input-scale is 0 the side input won't be read, so we + // just pass in the output buffer, since it's handy and has the correct + // size. + side_input = output_buf; + } + + stream->ThenFusedConvolveWithAlgorithm( + input_descriptor, input_buf, params.conv_result_scale, + filter_descriptor, filter_buf, convolution_descriptor, side_input, + params.fusion->side_input_scale, bias_desc, + DeviceMemory(params.fusion->bias_buf), params.fusion->mode, + output_descriptor, &output_buf, scratch_allocator, algorithm, + profile_result); + break; + } } if (!stream->ok()) { return InternalError( - "Unable to launch convolution with type %s and algorithm (%lld, %lld)", - CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(), + "Unable to launch convolution with type %s and algorithm (%d, %d)", + CudnnConvKindToString(kind), algorithm.algorithm().algo_id(), algorithm.algorithm_no_scratch().algo_id()); } return Status::OK(); } -} // anonymous namespace +// Returns the cudnn convolution parameters generated from conv, which must be a +// custom-call to a cudnn convolution. +StatusOr GetCudnnConvParams( + const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer) { + CudnnConvParams params; -string CudnnConvKindToString(CudnnConvKind kind) { - switch (kind) { - case CudnnConvKind::kForward: - return "forward"; - case CudnnConvKind::kBackwardFilter: - return "backward_filter"; - case CudnnConvKind::kBackwardInput: - return "backward_input"; + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + conv->backend_config()); + const auto& target = conv->custom_call_target(); + const auto& lhs_shape = conv->operand(0)->shape(); + const auto& rhs_shape = conv->operand(1)->shape(); + const auto& conv_result_shape = conv->shape().tuple_shapes(0); + + params.window = &conv->window(); + params.dnums = &conv->convolution_dimension_numbers(); + params.feature_group_count = conv->feature_group_count(); + params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( + backend_config.algorithm(), backend_config.tensor_ops_enabled())); + params.conv_result_scale = backend_config.conv_result_scale(); + + if (target == kCudnnConvForwardCallTarget) { + params.kind = CudnnConvKind::kForward; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + } else if (target == kCudnnConvBackwardInputCallTarget) { + params.kind = CudnnConvKind::kBackwardInput; + params.input_shape = &conv_result_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &lhs_shape; + params.input_buf = result_buffer; + params.filter_buf = operand_buffers[1]; + params.output_buf = operand_buffers[0]; + } else if (target == kCudnnConvBackwardFilterCallTarget) { + params.kind = CudnnConvKind::kBackwardFilter; + params.input_shape = &lhs_shape; + params.filter_shape = &conv_result_shape; + params.output_shape = &rhs_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = result_buffer; + params.output_buf = operand_buffers[1]; + } else if (target == kCudnnConvBiasActivationForwardCallTarget) { + params.kind = CudnnConvKind::kForwardActivation; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.fusion.emplace(); + auto& fusion = *params.fusion; + if (backend_config.activation_mode() < + static_cast(se::dnn::ActivationMode::kNumActivationModes)) { + fusion.mode = static_cast( + backend_config.activation_mode()); + } else { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.side_input_scale = backend_config.side_input_scale(); + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + params.fusion->bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + params.fusion->side_input_buf = operand_buffers[3]; + } + } else { + return InternalError("Unexpected custom call target: %s", target); } + return params; } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { +} // anonymous namespace + +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, algorithm, - stream, profile_result); + return RunCudnnConvolution(conv, operand_buffers, result_buffer, + &scratch_allocator, stream, profile_result); } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - PrimitiveType output_primitive_type = output_shape.element_type(); - CHECK(output_primitive_type == F32 || output_primitive_type == F16) - << ShapeUtil::HumanString(output_shape); - if (output_primitive_type == F32) { - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, dnums, - algorithm, stream, profile_result); +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + TF_ASSIGN_OR_RETURN(CudnnConvParams params, + GetCudnnConvParams(conv, operand_buffers, result_buffer)); + + PrimitiveType output_primitive_type = + conv->shape().tuple_shapes(0).element_type(); + switch (output_primitive_type) { + case F16: + return RunCudnnConvolutionImpl(params, scratch_allocator, + stream, profile_result); + case F32: + return RunCudnnConvolutionImpl(params, scratch_allocator, stream, + profile_result); + case F64: + return RunCudnnConvolutionImpl(params, scratch_allocator, stream, + profile_result); + default: + LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index 944e4ac686d45408b08ff1faa321510c1c8920ba..61aec1ceccec0f253f9ddaa688d64cacea800cf3 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -27,39 +30,8 @@ namespace gpu { // This file contains low-level routines for running cudnn convolutions. -// Different types of convolutions supported by cudnn. -// -// A way to think about these is that a convolution is defined by three arrays -// -- the "input", the "filter", and the "output" -- and given any two of these, -// we can compute the third. For example, a backward-input convolution takes as -// input a filter and an "output" and produces an "input" such that if one were -// to do a forward convolution of "input" using filter, the result would be -// something with the same shape as "output". -// -// This way of thinking is not correct if you look at the values produced. For -// example, a backward-input convolution is not actually the mathematical -// inverse of a forward convolution. But it's right as far as the shapes and -// "connectivity" (i.e. which elements of the input affect which elements of -// the output) are concerned. -enum class CudnnConvKind { - kForward, // input + filter => output - kBackwardInput, // filter + output => input - kBackwardFilter, // input + output => filter -}; - -// Converts a CudnnConvKind value to a string. -string CudnnConvKindToString(CudnnConvKind kind); - // Calls into cudnn to run the specified convolution. // -// Note that depending on the value of CudnnConvKind, the result of this call -// may be written into input_buf, filter_buf, or output_buf! -// -// At the moment we only support cudnn convolutions over float and half, and -// convolution with half data type is implemented with cudnn PSEUDO_HALF -// configuration, that is, the input values are half and the internal -// computation type is float. -// // We provide one overload which takes a scratch buffer, and another which takes // an allocator which is responsible for allocating the scratch space. In // theory the second one shouldn't be necessary -- users of this function could @@ -70,23 +42,18 @@ string CudnnConvKindToString(CudnnConvKind kind); // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); - -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); + +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..3761c19cfcab10e0c6faa17c2d1d535d706ff6c5 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc @@ -0,0 +1,278 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { +namespace { + +// Describes a matched pattern: +// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); +// Where side_input has the shape of output buffer, and bias is a 1D array with +// the dimension of number of output features. +struct ConvWithRelu { + HloInstruction* maximum; + HloCustomCallInstruction* conv; + HloInstruction* bias; + HloInstruction* side_input; + HloConstantInstruction* alpha_conv; + HloConstantInstruction* alpha_side_input; +}; + +absl::optional FindConvWithRelu(HloInstruction* instr) { + using match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Broadcast; + using match::Constant; + using match::GetTupleElement; + using match::Maximum; + using match::MultiplyAnyOrder; + using match::Op; + + // The pattern we want to match: + // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); + // + // With its variants involving commute/reassociation of adds, multiplies, and + // max, and omission of alpha1, side_input, alpha2, or bias. + + HloInstruction* relu_input; + + // Match max(0, relu_input). + auto zero_pattern = Broadcast(match::ConstantScalar(0)); + if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && + !Match(instr, Maximum(Op(&relu_input), zero_pattern))) { + return absl::nullopt; + } + HloInstruction* conv_instr = nullptr; + HloInstruction* alpha_conv_instr = nullptr; + HloInstruction* alpha_side_input_instr = nullptr; + HloInstruction* bias_broadcast_instr = nullptr; + HloInstruction* bias = nullptr; + HloInstruction* side_input = nullptr; + + // These nodes will not be in the returned value, but we need to check them + // for single use. + HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr, + *mul1 = nullptr, *mul2 = nullptr; + + const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); + const auto conv_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); + auto conv_pattern = GetTupleElement( + >e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); + return AnyOf( + MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); + }(); + const auto side_input_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); + // If bias is already matched, match arbitrary additional input as side + // input. Note this may force a cheap operation (e.g. broadcast) to be + // materialized into a large buffer, as large as the output buffer. + // + // TODO(timshen): If in practice there are significant false positives, we + // should fix it. + auto side_input_pattern = Op(&side_input); + return AnyOf( + MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern), + side_input_pattern); + }(); + + { + // Try to match any of the following form of add, in any association: + // addends[0] + // addends[0] + addends[1] + // addends[0] + addends[1] + addends[2] + // + // Then try to match each addend with one of the three patterns: bias, conv, + // or side_input. Notice that side_input matching must go last, as it + // also matches a conv or a bias. + HloInstruction* addends[3] = {nullptr, nullptr, nullptr}; + auto add3_pattern = [&] { + auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1])); + return AnyOf( + AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern, + Op(&addends[0])); + }(); + CHECK(Match(relu_input, add3_pattern)); + for (auto addend : addends) { + if (addend) { + if (bias == nullptr && Match(addend, bias_pattern)) { + CHECK(bias); + } else if (conv_instr == nullptr && Match(addend, conv_pattern)) { + CHECK(conv_instr); + } else if (side_input == nullptr && Match(addend, side_input_pattern)) { + CHECK(side_input); + } else { + return absl::nullopt; + } + } + } + } + + if (conv_instr == nullptr) { + return absl::nullopt; + } + + for (HloInstruction* instr : + {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) { + if (instr && instr->user_count() > 1) { + return absl::nullopt; + } + } + + auto conv = Cast(conv_instr); + auto bias_broadcast = + CastOrNull(bias_broadcast_instr); + + if (conv->custom_call_target() != kCudnnConvForwardCallTarget) { + return absl::nullopt; + } + + if (bias_broadcast) { + // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}. + if (bias_broadcast_instr->dimensions().size() != 1) { + return absl::nullopt; + } + if (bias_broadcast_instr->dimensions(0) != + conv->convolution_dimension_numbers().output_feature_dimension()) { + return absl::nullopt; + } + } + + return ConvWithRelu{ + instr, + conv, + bias, + side_input, + CastOrNull(alpha_conv_instr), + CastOrNull(alpha_side_input_instr)}; +} + +StatusOr> TryRewriteToCudnnForwardRelu( + ConvWithRelu match) { + auto conv = match.conv; + + HloComputation* computation = conv->parent(); + PrimitiveType element_type = conv->operand(0)->shape().element_type(); + + const auto get_alpha_value = + [](HloConstantInstruction* instr) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto alpha, + Cast(instr)->literal().Convert(F64)); + return alpha.GetFirstElement(); + }; + + double alpha_conv = 1; + if (match.alpha_conv) { + TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv)); + } + + double alpha_side_input; + if (match.side_input) { + if (match.alpha_side_input) { + TF_ASSIGN_OR_RETURN(alpha_side_input, + get_alpha_value(match.alpha_side_input)); + } else { + alpha_side_input = 1; + } + } else { + CHECK(match.alpha_side_input == nullptr); + alpha_side_input = 0; + } + + auto bias = match.bias; + if (!bias) { + auto zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + + int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions( + conv->convolution_dimension_numbers().output_feature_dimension()); + bias = computation->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShapeWithDescendingLayout(element_type, + {num_output_feature}), + zero, {})); + } + + CHECK(bias); + std::vector args = {conv->mutable_operand(0), + conv->mutable_operand(1), bias}; + if (match.side_input) { + args.push_back(match.side_input); + } + auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall( + conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget)); + new_conv->set_window(conv->window()); + new_conv->set_convolution_dimension_numbers( + conv->convolution_dimension_numbers()); + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config()); + config.set_activation_mode( + static_cast(se::dnn::ActivationMode::kRelu)); + config.set_conv_result_scale(alpha_conv); + config.set_side_input_scale(alpha_side_input); + TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); + + VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name(); + return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0), + new_conv, 0); +} + +} // namespace + +StatusOr CudnnFusedConvolutionRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + std::vector matches; + int num_forward_convs = 0; + for (auto instr : computation->instructions()) { + auto match = FindConvWithRelu(instr); + if (match.has_value()) { + matches.push_back(*match); + } + if (auto call = DynCast(instr)) { + if (call->custom_call_target() == kCudnnConvForwardCallTarget) { + num_forward_convs++; + } + } + } + VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size() + << " out of " << num_forward_convs << " forward convs."; + std::vector>> + replacements; + for (const ConvWithRelu& match : matches) { + TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match)); + replacements.push_back({match.maximum, std::move(new_instr)}); + changed = true; + } + for (auto& replacement : replacements) { + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + replacement.first, std::move(replacement.second))); + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..bd12aadded9dd9e19bc695ddc11e5529931a306a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +class CudnnFusedConvolutionRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "cudnn-fused-convolution-rewriter"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index cc38db27e2680e950f74e104cef8829585c7b81c..c1aaa4bf04ddc31edf723c056805ae5aad994e55 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -23,6 +23,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" // IWYU pragma: no_include "llvm/IR/Attributes.gen.inc" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" @@ -43,16 +45,14 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gpu { +using absl::StrAppend; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrAppend; namespace { // Returns whether operand is a floating-point literal with the given value. @@ -74,10 +74,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter( compute_nested_(std::move(compute_nested)) {} StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // The libdevice math functions differentiate between "double" and "float" by // appending an 'f' to the function's name. libdevice doesn't have f16 math // functions, so we convert the operands to f32 before calling the function @@ -94,7 +92,7 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( for (int64 i = 0; i < operands.size(); ++i) { if (input_types[i] == F16) { converted_operands[i] = - b_->CreateFPCast(converted_operands[i], b_->getFloatTy()); + FPCast(converted_operands[i], b_->getFloatTy()); converted_input_types[i] = F32; } } @@ -107,22 +105,20 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( break; default: return Unimplemented("Bad type for libdevice math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } llvm::Value* result = EmitMathCall(munged_callee, converted_operands, converted_input_types, output_type) .ValueOrDie(); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // llvm intrinsics differentiate between half/float/double functions via // the suffixes ".f16", ".f32" and ".f64". string munged_callee = callee_name; @@ -138,22 +134,20 @@ StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( break; default: return Unimplemented("Bad type for llvm intrinsic math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } return EmitMathCall(munged_callee, operands, input_types, output_type); } StatusOr GpuElementalIrEmitter::EmitMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // Binary math functions transform are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { return Unimplemented("Input type ≠ output type: %s ≠ %s", - PrimitiveType_Name(input_type).c_str(), - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(input_type), + PrimitiveType_Name(output_type)); } } @@ -163,8 +157,7 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( } StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); @@ -183,8 +176,7 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( } StatusOr GpuElementalIrEmitter::EmitPowerOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { CHECK_EQ(op->opcode(), HloOpcode::kPower); PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); @@ -210,13 +202,15 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( return make_sqrt(); } - if (hlo_module_config_.debug_options().xla_enable_fast_math() && - IsFPLiteralWithValue(rhs, -.5)) { + if (IsFPLiteralWithValue(rhs, -.5)) { VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString(); // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX // rsqrt.approx instruction. + // + // TODO(jlebar): Does this happen with fastmath disabled? If not, should + // we force-enable it? TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); - return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); + return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString(); @@ -225,82 +219,74 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( } StatusOr GpuElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { + PrimitiveType prim_type, llvm::Value* value) { return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitLog( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitLog1p( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitSin( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitCos( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitExp( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitExpm1( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type); } StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - PrimitiveType output_type = op->shape().element_type(); - switch (op->opcode()) { - case HloOpcode::kTanh: - // If we don't care much about precision, emit a fast approximation of - // tanh. - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - // Upcast F16 to F32 if necessary. - llvm::Type* type = - input_type == F16 ? b_->getFloatTy() : operand_value->getType(); - llvm::Value* input = b_->CreateFPCast(operand_value, type); - llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); - return b_->CreateFPCast(fast_tanh, operand_value->getType()); - } - return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, - output_type); - default: - return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); - } +StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { + // Emit a fast approximation of tanh instead of calling __nv_tanh. + // __nv_tanh is particularly bad because it contains branches, thus + // preventing LLVM's load-store vectorizer from working its magic across a + // function which contains tanh calls. + // + // This routine isn't numerically precise, but it's good enough for ML. + + // Upcast F16 to F32 if necessary. + llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType(); + llvm::Value* input = FPCast(value, type); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); + return FPCast(fast_tanh, value->getType()); } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type, + absl::Span attributes) { std::vector ir_input_types; for (PrimitiveType input_type : input_types) { ir_input_types.push_back( @@ -320,29 +306,28 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( callee->addFnAttr(attribute); } - return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); + return Call(callee, llvm_ir::AsArrayRef(operands)); } -llvm::Value* GpuElementalIrEmitter::EmitThreadId() const { - llvm::Value* block_id = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block), - thread_id_in_block); +llvm::Value* GpuElementalIrEmitter::EmitThreadId() { + llvm::Value* block_id = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "block.id"); + llvm::Value* thread_id_in_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kMap: return [=, &operand_to_generator]( @@ -388,7 +373,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(init_value, accum_ptr); + Store(init_value, accum_ptr); } llvm::Type* index_type = index.GetType(); @@ -410,22 +395,21 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( IrArray::Index input_index(index_type, index.size()); llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = b_->CreateNSWMul( + llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); - input_index[i] = b_->CreateNSWSub( - b_->CreateNSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + input_index[i] = + NSWSub(NSWAdd(stridden_index, window_index[i]), + index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This // comparison is equivalent to the unsigned comparison // input_index[i] < bound, as a negative value wraps to a large // positive value. - in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpULT( - input_index[i], - index_typed_const(operand->shape().dimensions(i)))); + in_bounds = + And(in_bounds, + ICmpULT(input_index[i], + index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = @@ -437,14 +421,15 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( operand_to_generator.at(operand)(input_index)); TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), - {b_->CreateLoad(accum_ptr), input_value})); - b_->CreateStore(accum_value, accum_ptr); + compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value})); + Store(accum_value, accum_ptr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); - return b_->CreateLoad(accum_ptr); + return Load(accum_ptr); }; case HloOpcode::kReduce: + // TODO(b/112040122): This should be supported. + CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce"; return [=, &operand_to_generator]( const IrArray::Index& output_index) -> StatusOr { const HloInstruction* operand = hlo->operand(0); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index e3eacef133cb8b615a645ca2f11dd6dedf9f0176..e8b56a39ce58b6aab35c1c977553c7ff7e753273 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace gpu { @@ -38,9 +38,9 @@ namespace gpu { class GpuElementalIrEmitter : public ElementalIrEmitter { public: // A NestedComputer computes an element of the output of the given computation - // given an ArraySlice of its input elements. + // given a Span of its input elements. using NestedComputer = std::function( - const HloComputation&, tensorflow::gtl::ArraySlice)>; + const HloComputation&, absl::Span)>; GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config, llvm::Module* module, llvm::IRBuilder<>* b, @@ -48,85 +48,77 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: - StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; - - StatusOr EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const override; + StatusOr EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value) override; StatusOr EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitExp(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; + + StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) override; - llvm::Value* EmitThreadId() const override; + llvm::Value* EmitThreadId() override; private: // Emits IR for op, which must have opcode kPower. StatusOr EmitPowerOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Emits IR to call a device function named "callee_name" on the given // operand. Returns the IR value that represents the return value. llvm::Value* EmitDeviceFunctionCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_type, - PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes) const; + const string& callee_name, absl::Span operands, + absl::Span input_type, PrimitiveType output_type, + absl::Span attributes); // Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. StatusOr EmitLlvmIntrinsicMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); // Emits IR to call a libdevice function of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. StatusOr EmitLibdeviceMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); // Emits IR to call a function of type [T] -> T. Does not munge callee_name. // Returns the IR value that represents the return value of the function. StatusOr EmitMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); const HloModuleConfig& hlo_module_config_; NestedComputer compute_nested_; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 0cdddf8bcfd4e849b311bf810eda471d79dbf106..ca4a605af5d3b6b58b603d7ddad60ed9ae8a212f 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -17,11 +17,11 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -43,8 +43,8 @@ StatusOr> FftScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -92,8 +92,7 @@ string FftTypeToString(se::fft::Type type) { } // namespace -FftThunk::FftThunk(FftType fft_type, - tensorflow::gtl::ArraySlice fft_length, +FftThunk::FftThunk(FftType fft_type, absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, const Shape& input_shape, const Shape& output_shape, @@ -213,7 +212,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, return Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this, - FftTypeToString(fft_type_).c_str()); + FftTypeToString(fft_type_)); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 8c53be5077b0c5a88d303c729457139c6cb800f1..2be50e08bd2b561b44245b20e1fb200e31e65a41 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -62,7 +62,7 @@ class FftThunk : public Thunk { public: // Constructs a thunk for launching an FFT on a stream. // Semantics of null hlo_instruction argument are as in Thunk. - FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice fft_length, + FftThunk(FftType fft_type, absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, const Shape& input_shape, const Shape& output_shape, diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 2fd2206324e5f763490780a54880825a772b7ea2..88f0b4d71c915c37f0b58cb91a8788fd8f9cc452 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -28,7 +28,7 @@ ForThunk::ForThunk(const int64 loop_limit, const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), loop_limit_(loop_limit), - body_thunk_sequence_(MakeUnique( + body_thunk_sequence_(absl::make_unique( // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // constructor because this SequentialThunk is logically "part of" // this ForThunk, and shouldn't be profiled separately from it. diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 3cd30b754c3242f00c704de1afab2282ed827b41..30c1f9088968305ad0207164ecb07ba13cc89ee6 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -18,12 +18,14 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace gpu { @@ -64,10 +66,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Slice for a more accurate estimate of bytes read. double bytes = 0.0; for (auto& instruction : instructions) { - if (c_all_of(instruction->users(), [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kSlice || - instruction->opcode() == HloOpcode::kDynamicSlice; - })) { + if (absl::c_all_of( + instruction->users(), [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSlice || + instruction->opcode() == HloOpcode::kDynamicSlice; + })) { // All users are slice: accumulate bytes of all user slice instructions. for (auto& user : instruction->users()) { bytes += ShapeUtil::ByteSizeOf(user->shape()); @@ -223,10 +226,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. - if (!c_all_of(fusion->users(), [](const HloInstruction* user) { + if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || - user->fusion_kind() == HloInstruction::FusionKind::kInput); + (user->fusion_kind() == HloInstruction::FusionKind::kInput && + LayoutsAreReduceInputFusionFriendly(*fusion, *user))); })) { VLOG(3) << "Not merging " << fusion->name() << ": Some of its users are not loop/input fusion kernels."; @@ -241,11 +245,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // If 'fusion' has just one user, then an earlier fusion pass chose not to // fuse this producer/comsumer pair (likely because of expensive instruction // re-use by the consumer), and so we honor that choice here as well. - if (c_any_of(fusion->fused_instructions(), - [](const HloInstruction* instruction) { - return instruction->opcode() != HloOpcode::kParameter && - GpuInstructionFusion::IsExpensive(*instruction); - })) { + if (absl::c_any_of(fusion->fused_instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() != HloOpcode::kParameter && + GpuInstructionFusion::IsExpensive(*instruction); + })) { VLOG(3) << "Not merging " << fusion->name() << ": Contains one or more expensive instructions."; ++num_fail_expensive_fused_instruction_; @@ -287,11 +291,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { << " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion) << " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio << " into users { " - << tensorflow::str_util::Join(users, ", ", - [](string* out, HloInstruction* user) { - tensorflow::strings::StrAppend( - out, user->name()); - }) + << absl::StrJoin(users, ", ", + [](string* out, HloInstruction* user) { + absl::StrAppend(out, user->name()); + }) << " }"; // Remove 'fusion' instruction. CHECK_EQ(0, fusion->user_count()); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index 4c523a66de977cd32423b25f0d165c4f4ba51c4a..f19996edfe3dd923aa686a19621ce28a4aed5a45 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -32,9 +32,9 @@ namespace gpu { // 2) The result of merging the fusion instruction into its users would not // increase bytes transferred. // -class FusionMerger : public HloPassInterface { +class FusionMerger : public HloModulePass { public: - tensorflow::StringPiece name() const override { return "fusion merger"; } + absl::string_view name() const override { return "fusion merger"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index b22bb1d39ba177ef42673c7a3755694b43c15d14..7cc869ed9e89688d6ea06428a7bade3ebe55ea23 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -286,6 +286,39 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { op::Fusion(op::Parameter())); } +TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) { + auto module = ParseHloString(R"( + HloModule m + + f1_computation { + f1_p0 = f32[16,16,256]{0,1,2} parameter(0) + add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0) + // Note that the copy changes the layout from {0,1,2} to {2,1,0}. + ROOT f1_root = f32[16,16,256]{2,1,0} copy(add) + } + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + f2_computation { + f2_p0 = f32[16,16,256]{2,1,0} parameter(0) + f2_zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[16,16,256]{0,1,2} parameter(0) + f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation + ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation + })") + .ValueOrDie(); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 74282c568c09921dbeec2e9cce79b6c73b6ea592..9c4a4903667ea1a6c99ce9e912c9d0497b8e389f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -186,7 +186,7 @@ StatusOr DoGemmAutotune( } return InternalError( - "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms " + "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms " "ran successfully", stream, algorithms.size()); } diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 939c7f85e35b4fcb943a25aa6346d72798432920..12c81f9bfc6bfdac63edf9c826b835057107fa41 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -52,12 +52,12 @@ class GemmThunk : public Thunk { se::Stream* stream, HloExecutionProfiler* profiler) override; - // Returns true if we'll perform autotuning if run on the given stream. If - // so, we want the GPU to be quiescent during autotuning, so as not to - // introduce noise in our results. - bool ShouldHaltAllActivityBeforeRunning(se::Stream* stream) override { - return autotune_results_.count( - stream->parent()->GetDeviceDescription().name()) != 0; + bool WillAutotuneKernel(se::Stream* stream) override { + // We will autotune this kernel if we don't already have a autotune result + // for the stream device. + return autotune_results_.find( + stream->parent()->GetDeviceDescription().name()) == + autotune_results_.end(); } private: @@ -75,6 +75,8 @@ class GemmThunk : public Thunk { // results. The map's value is the best algorithm we've found for this thunk // on this device, or an error if none of the algorithms worked and we should // use the regular gemm without an algorithm. + // + // TODO(b/112415150): Make this thread safe. std::unordered_map> autotune_results_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 75f414e47fe3edcc1b10b392ed5cc5038be6c190..79c74e7e8bf3a1aa59243b81942d29180bb46e74 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -34,15 +34,6 @@ namespace xla { namespace gpu { -StatusOr GpuCopyInsertion::FindOrInsertCopy( - HloInstruction* hlo) { - HloInstruction*& copy = hlo_to_copy_map_[hlo]; - if (copy == nullptr) { - TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); - } - return copy; -} - StatusOr GpuCopyInsertion::Run(HloModule* module) { CopyInsertion generic_copy_insertion; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 0c6f9b511f3aac5f62182273b827adcd068cd633..4c7e38ffeb60f87a4f27e212572ae31cca8e0947 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -25,20 +25,11 @@ namespace gpu { // Besides the modifications made by the generic xla::CopyInsertion, this // GPU-specific copy insertion also materializes operands of library calls by // inserting kCopy instructions. -class GpuCopyInsertion : public HloPassInterface { +class GpuCopyInsertion : public HloModulePass { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr Run(HloModule* module) override; - - protected: - // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making - // duplicate copies. - StatusOr FindOrInsertCopy(HloInstruction* hlo); - - // A map containing all copies inserted to materialize operands of library - // calls. The key is the copied instruction and the value is the copy. - tensorflow::gtl::FlatMap hlo_to_copy_map_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index bb7736efa65c49766ea88a43fdab2b7102100c11..31a9f9b1beb81da81a06f6dc8e7c13c105514092 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -112,7 +112,7 @@ Status GpuExecutable::ExecuteThunks( // // TODO(jlebar): Should we cache the results of HloInstruction::ToString(), // since we expect it to be an expensive call? - tensorflow::gtl::optional op_annotation; + absl::optional op_annotation; if (top_level_annotation.IsEnabled()) { op_annotation.emplace( thunk->hlo_instruction() != nullptr @@ -131,9 +131,10 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - // If this thunk requests it, wait for all currently-executing thunks to - // finish. This is useful e.g. if the thunk is about to perform autotuning. - if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { + // If this thunk is about to autotune then wait for all currently executing + // thunks to finish. This reduces noise and thus the probability of + // choosing a suboptimal algorithm. + if (thunk->WillAutotuneKernel(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); } @@ -143,7 +144,7 @@ Status GpuExecutable::ExecuteThunks( TF_RETURN_IF_ERROR( thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); if (thunk_schedule_->Depended(thunk)) { - auto finish_event = MakeUnique(main_stream->parent()); + auto finish_event = absl::make_unique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); @@ -159,7 +160,7 @@ Status GpuExecutable::ExecuteThunks( if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", - main_stream, block_status.error_message().c_str()); + main_stream, block_status.error_message()); } } @@ -233,7 +234,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { DeviceMemoryAllocator* memory_allocator = run_options->allocator(); @@ -259,10 +260,9 @@ StatusOr GpuExecutable::ExecuteOnStream( if (buffer.is_null() && buffer.size() > 0) { return FailedPrecondition( "Cannot run XLA computation because pointer to (sub-)buffer at " - "index %s of parameter %lld was null. All pointers to " - "(sub-)buffers must not be null, unless the (sub-)buffer has zero " - "elements.", - allocation.param_shape_index().ToString().c_str(), param_no); + "index %s of parameter %d was null. All pointers to (sub-)buffers " + "must not be null, unless the (sub-)buffer has zero elements.", + allocation.param_shape_index().ToString(), param_no); } buffer_allocations_builder.RegisterBuffer(i, buffer); @@ -325,7 +325,7 @@ StatusOr GpuExecutable::ExecuteOnStream( StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on GPU."); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index c7ce6d0acbbbe594040271c0d45c71c016e36514..38b0f8f15bd28cf2659e4a53b6634e981545716b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -19,6 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -32,10 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -78,12 +78,12 @@ class GpuExecutable : public Executable { // match the compute capability passed to this object's constructor. StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; private: // If `block_host_until_done` is false, execution will not block the host diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d31fd5570c468b0c42fa308535fd335f3588a79 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" + +namespace xla { +namespace gpu { + +namespace { +void AppendParams(const HloInstruction& instr, + std::vector* params) { + if (instr.opcode() == HloOpcode::kFusion) { + params->insert(std::end(*params), std::begin(instr.fused_parameters()), + std::end(instr.fused_parameters())); + } else { + for (HloInstruction* operand : instr.operands()) { + params->push_back(operand); + } + } +} +} // namespace + +bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, + const HloInstruction& reduce) { + std::vector params; + AppendParams(producer, ¶ms); + AppendParams(reduce, ¶ms); + int64 max_rank = -1; + const Layout* max_rank_layout; + for (HloInstruction* param : params) { + if (ShapeUtil::IsArray(param->shape()) && + ShapeUtil::Rank(param->shape()) > max_rank) { + max_rank = ShapeUtil::Rank(param->shape()); + max_rank_layout = ¶m->shape().layout(); + } + } + return absl::c_all_of(params, [&](HloInstruction* param) { + return (!ShapeUtil::IsArray(param->shape())) || + (ShapeUtil::Rank(param->shape()) < max_rank) || + (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); + }); +} + +bool IsInputFusibleReduction(const HloInstruction& instr) { + if (instr.IsMultiOutputFusion()) { + for (const HloInstruction* operand : + instr.fused_expression_root()->operands()) { + if (IsReductionToVector(*operand)) { + CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + << " Multi-output fusion rooted at reduction-to-vector ops must be " + "of kind kInput: " + << instr.ToString(); + return true; + } + } + return false; + } else if (instr.opcode() == HloOpcode::kFusion) { + if (IsReductionToVector(*instr.fused_expression_root())) { + CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + << " Fusion rooted at reduction-to-vector op must be of kind kInput: " + << instr.ToString(); + return true; + } + return false; + } + return IsReductionToVector(instr); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h new file mode 100644 index 0000000000000000000000000000000000000000..f7c24a0d5bbfcc61389ea19ae7f769671e4e974d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +// TODO(b/112957171): Extract logic to determine fusibility of HLO ops from +// GpuInstructionFusion, FusionMerger, and GpuMultiOutputFusion. + +namespace xla { +namespace gpu { + +// The code emitted for reduce-rooted input fusions (EmitReductionToVector) +// suffers from poor data locality if the layouts of input parameters differ. In +// such situtations it is better not to fuse. Only input params with +// maximum rank are considered. Params with smaller ranks will be broadcasted +// and have not been observed to cause data locality issues. +// TODO(b/111977086): Improve reduce emitters to remove this limitation. +bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, + const HloInstruction& reduce); + +// Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr` +// is either an unfused reduction-to-vector op, an input fusion rooted at a +// reduction-to-vector op, or a multi-output input fusion with at least one +// reduction-to-vector op root. +// Note that reduction ops are lowered in different ways. Reduce input fusions +// are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at +// reduction-to-vector ops. Other reduction ops are lowered by +// GpuElementalIrEmitter and fused like elementwise ops. +bool IsInputFusibleReduction(const HloInstruction& instr); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d91b7bc61fda5a07c163a07ec0e1644d2ad9db49 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -0,0 +1,332 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { + +using GpuFusibleTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + scalar_add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + })"; + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_ElementwiseProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY entry { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + ROOT reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + const HloInstruction* exp = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(exp->opcode(), HloOpcode::kExp); + EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*exp, *reduce)); +} + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_MixedLayoutProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + mixed_input_layouts_computation { + p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) + c0 = f16[] constant(0) + broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} + greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) + } + fused_reduce { + p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2) + c0.2 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation + reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce + ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion) + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce_fusion = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(), + HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(1); + ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect); + EXPECT_FALSE( + LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion)); +} + +TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduce { + p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0) + c0.1 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(p0.1, c0.1), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + copy = f32[128,1024,32,32]{1,3,2,0} copy(p0) + ROOT reduce_fusion = f32[1024]{0} fusion(copy), kind=kInput, calls=fused_reduce + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->fused_expression_root()->opcode(), HloOpcode::kReduce); + const HloInstruction* copy = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(copy->opcode(), HloOpcode::kCopy); + EXPECT_FALSE(LayoutsAreReduceInputFusionFriendly(*copy, *reduce)); +} + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_LayoutChangingFusionProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + layout_changing_computation { + p0.1 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + c0 = f16[] constant(0) + broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={} + greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast) + select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast) + ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select) + } + fused_reduce { + p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2) + c0.2 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=layout_changing_computation + ROOT reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce_fusion = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(), + HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kCopy); + EXPECT_FALSE( + LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion)); +} + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + broadcasting_computation { + p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0) + p1.1 = f32[128]{0} parameter(1) + broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0} + ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast) + } + ENTRY entry { + p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1 = f16[128]{0} parameter(1) + loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation + c0.2 = f32[] constant(0) + ROOT reduce = f32[128,1024]{0,1} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kAdd); + EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY entry { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + // Reduction-to-vector lowered by IrEmitterUnnested. + ROOT reduce = f32[512]{0} reduce(p1, c0), dimensions={0,2,3}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY entry { + c0 = f32[] parameter(0) + p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1) + // Reduction lowered by GpuElementalIrEmitter. + ROOT reduce = f32[8,512,5,1,1]{4,3,2,1,0} reduce(p1, c0), dimensions={3}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + ROOT reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = f32[128,512]{1,0} fusion(p0), kind=kInput, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1) + ROOT reduce = f32[8,5,1,1]{3,2,1,0} reduce(p1, c0), dimensions={1,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(0) + ROOT fusion = f32[8,5,1,1]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce.0 = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + reduce.1 = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + ROOT root = (f32[128,512]{1,0}, f32[128,512]{1,0}) tuple(reduce.0, reduce.1) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[128,512]{1,0}, f32[128,512]{1,0}) fusion(p0), kind=kInput, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, + IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1) + ROOT root = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce.0 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add + reduce.1 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add + ROOT root = (f32[512,28]{1,0}, f32[512,28]{1,0}) tuple(reduce.0, reduce.1) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[512,28]{1,0}, f32[512,28]{1,0}) fusion(p0), kind=kLoop, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, + IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1) + ROOT root = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc similarity index 91% rename from tensorflow/compiler/xla/service/gpu/hlo_schedule.cc rename to tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 19de37b0fbed15455e8c6a9bfe427ba3d9f0a9dc..02a0d028c118aba23996f9b97d05443bb4a00c88 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -17,12 +17,13 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -59,8 +60,8 @@ GpuHloOrdering::GpuHloOrdering( : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. if (stream_assignment.StreamCount() == 1) { - entry_sequence_ = - MakeUnique>(thunk_launch_order); + entry_sequence_ = absl::make_unique>( + thunk_launch_order); } // The ordering of instructions for the entry computation is determined by the @@ -75,7 +76,7 @@ GpuHloOrdering::GpuHloOrdering( // same-stream predecessors of each instruction. // Compute the set of all instructions we will want to set reachability on. - auto predecessor_map = MakeUnique( + auto predecessor_map = absl::make_unique( module->entry_computation()->MakeInstructionPostOrder()); // The most recently visited instruction per stream. @@ -184,13 +185,13 @@ void BFSLaunchOrder(const HloComputation* computation, } // end namespace -HloSchedule::HloSchedule() {} +GpuHloSchedule::GpuHloSchedule() {} /* static */ -StatusOr> HloSchedule::Build( +StatusOr> GpuHloSchedule::Build( const HloModule& module, const StreamAssignment& stream_assignment, int64 pointer_size) { - std::unique_ptr schedule(new HloSchedule); + std::unique_ptr schedule(new GpuHloSchedule); // Initialize thunk_launch_order_, the total order of thunk launches. const HloComputation* entry_computation = module.entry_computation(); @@ -198,17 +199,18 @@ StatusOr> HloSchedule::Build( // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( - schedule->thunk_launch_order_, - ScheduleOneComputation( + HloInstructionSequence sequence, + ScheduleComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); + schedule->thunk_launch_order_ = sequence.instructions(); } else { // BFS tends to increase concurrency, but also increases memory usage. BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); } - schedule->hlo_ordering_ = MakeUnique( + schedule->hlo_ordering_ = absl::make_unique( &module, stream_assignment, schedule->thunk_launch_order_); return std::move(schedule); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h similarity index 78% rename from tensorflow/compiler/xla/service/gpu/hlo_schedule.h rename to tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 1ce7a48ac8fcbbad0b3697845681582fe806b322..07a7fc67aa555845c3de57e574ab582403ec0490 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ #include #include @@ -33,12 +33,14 @@ namespace gpu { // launches, because thunks may be scheduled onto concurrent streams. This // schedule is used by BufferAssigner to determine buffer liveness (i.e. to // minimize allocations), and also by ThunkSchedule to determine the thunk -// launch order. -class HloSchedule { +// launch order. This class differs from xla::HloSchedule in that HloSchedule +// represents a total order of all instructions in the module for backends which +// execute HLO instructions strictly sequentially. +class GpuHloSchedule { public: - // Constructs an HloSchedule for the given module, based on the given stream - // assignment. - static StatusOr> Build( + // Constructs an GpuHloSchedule for the given module, based on the given + // stream assignment. + static StatusOr> Build( const HloModule& module, const StreamAssignment& stream_assignment, int64 pointer_size); @@ -56,7 +58,7 @@ class HloSchedule { } private: - HloSchedule(); + GpuHloSchedule(); std::vector thunk_launch_order_; std::unique_ptr hlo_ordering_; @@ -65,4 +67,4 @@ class HloSchedule { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc similarity index 85% rename from tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc rename to tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 45f0a1c645b2875cf90d2c11cfb66c3dd855d097..b857fa775a76ec999b505a2a64332cc0c54cf00b 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -13,32 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class HloScheduleTest : public HloTestBase { +class GpuHloScheduleTest : public HloVerifiedTestBase { protected: using HloVec = std::vector; // Pre-canned shapes. Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); - static std::unique_ptr BuildHloSchedule( + static std::unique_ptr BuildGpuHloSchedule( const HloModule& module, const StreamAssignment& streams) { - return HloSchedule::Build(module, streams, /*pointer_size=*/8) + return GpuHloSchedule::Build(module, streams, /*pointer_size=*/8) .ConsumeValueOrDie(); } @@ -47,7 +49,7 @@ class HloScheduleTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", config); + return absl::make_unique("test_module", config); } HloVec RemoveHlo(const HloVec& input, @@ -64,7 +66,7 @@ class HloScheduleTest : public HloTestBase { // Test of a single stream, where data dependencies fully determine the // execution order. -TEST_F(HloScheduleTest, SequentialMatMul) { +TEST_F(GpuHloScheduleTest, SequentialMatMul) { HloComputation::Builder builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); @@ -72,10 +74,10 @@ TEST_F(HloScheduleTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -84,7 +86,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) { EXPECT_EQ(streams->StreamNumberForHlo(*dot1), streams->StreamNumberForHlo(*dot2)); - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); // Remove parameters, which are unordered. EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), HloVec({dot1, dot2})); @@ -122,7 +124,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) { // Test of a single stream, where data dependencies do not fully determine the // execution order, but the stream assignment does. -TEST_F(HloScheduleTest, SequentialAdd) { +TEST_F(GpuHloScheduleTest, SequentialAdd) { HloComputation::Builder builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); @@ -146,7 +148,7 @@ TEST_F(HloScheduleTest, SequentialAdd) { EXPECT_EQ(streams->StreamNumberForHlo(*add1), streams->StreamNumberForHlo(*add3)); - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); // Remove parameters, which are unordered. EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), HloVec({add1, add2, add3})); @@ -194,18 +196,18 @@ TEST_F(HloScheduleTest, SequentialAdd) { } // Test of two streams. -TEST_F(HloScheduleTest, ConcurrentMatMul) { +TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { HloComputation::Builder builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* add = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -214,7 +216,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { EXPECT_NE(streams->StreamNumberForHlo(*dot1), streams->StreamNumberForHlo(*dot2)); - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); // Remove parameters, which are unordered. HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y}); EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) || @@ -250,7 +252,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { } // Test of multiple streams. -TEST_F(HloScheduleTest, LatticeMatMul) { +TEST_F(GpuHloScheduleTest, LatticeMatMul) { // d00 -- layer 0 // / \ // d10 d11 -- layer 1 @@ -265,26 +267,26 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); @@ -306,7 +308,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { // We don't check the thunk launch order, since there are many valid total // orders, and it's annoying to express. - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); auto order = schedule->ConsumeHloOrdering(); const HloVec all_params( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc index 4944c41f7d8dc7a78a3cd094aee4d7087c74857e..4268fb2c7a813b3b53e4cd48746028a7b369f28e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr GpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "GPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index d63e213d2b1efab4bcff75541cc5ab33d7a07976..9c64b4d10c9d1b172f7bd89b5fdacda893488bf8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -23,14 +23,12 @@ namespace xla { // his pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the GPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). -class GpuHloSupportChecker : public HloPassInterface { +class GpuHloSupportChecker : public HloModulePass { public: GpuHloSupportChecker() = default; ~GpuHloSupportChecker() override = default; - tensorflow::StringPiece name() const override { - return "gpu_hlo_support_checker"; - } + absl::string_view name() const override { return "gpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 0a4089df4c954cafcbe241189ee79a0995683513..27a4d0b601f3807fe6b94dd6171a44f292921ede 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloTestBase { +class GpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("GPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index d033faee8d25ed81a1483f8314652ef999ab36c5..74352f26aa9c3a2ca597da21735438df92f863ab 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -21,8 +21,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -90,27 +92,33 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // operands and the output shape. Depending on the underlying algorithm, one of // { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( - HloInstruction* instr, LayoutConstraints* constraints) { - CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); - Shape input_shape; - Shape filter_shape; - Shape output_shape; - const auto& target = instr->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - input_shape = instr->operand(0)->shape(); - filter_shape = instr->operand(1)->shape(); - output_shape = instr->shape().tuple_shapes(0); - } else if (target == kCudnnConvBackwardInputCallTarget) { - input_shape = instr->shape().tuple_shapes(0); - filter_shape = instr->operand(1)->shape(); - output_shape = instr->operand(0)->shape(); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - input_shape = instr->operand(0)->shape(); - filter_shape = instr->shape().tuple_shapes(0); - output_shape = instr->operand(1)->shape(); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << instr->custom_call_target(); + HloCustomCallInstruction* instr, LayoutConstraints* constraints) { + Shape lhs_shape = instr->operand(0)->shape(); + Shape rhs_shape = instr->operand(1)->shape(); + Shape result_shape = instr->shape().tuple_shapes(0); + + Shape* input_shape; + Shape* filter_shape; + Shape* output_shape; + + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr)); + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + input_shape = &lhs_shape; + filter_shape = &rhs_shape; + output_shape = &result_shape; + break; + case CudnnConvKind::kBackwardInput: + input_shape = &result_shape; + filter_shape = &rhs_shape; + output_shape = &lhs_shape; + break; + case CudnnConvKind::kBackwardFilter: + input_shape = &lhs_shape; + filter_shape = &result_shape; + output_shape = &rhs_shape; + break; } { @@ -127,8 +135,9 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( } TF_ASSIGN_OR_RETURN( - std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), - *output_shape.mutable_layout()), + std::tie(*input_shape->mutable_layout(), + *filter_shape->mutable_layout(), + *output_shape->mutable_layout()), StreamExecutorConvLayoutsToXlaLayouts( instr->convolution_dimension_numbers(), input, filter, output)); } @@ -141,24 +150,23 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( instr, /*index=*/{0})); // Set layouts of the instructions' shapes. - if (target == kCudnnConvForwardCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(output_shape.layout(), *call_result_buf)); - } else if (target == kCudnnConvBackwardInputCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(input_shape.layout(), *call_result_buf)); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf)); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << instr->custom_call_target(); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(result_shape.layout(), *call_result_buf)); + // instr->operand(2), if exists, is the bias buffer. There is no need to + // assign layout to it, as it has only one dimension. + + // instr->opernad(3), if exists, is the side input buffer. + if (instr->operand_count() == 4) { + if (kind != CudnnConvKind::kForwardActivation) { + return InternalError( + "Invalid convolution. Conv has a side input, but kind is not fused " + "conv forward: %s", + instr->ToString()); + } + // The side input layout must match the output layout. + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3)); } return Status::OK(); } @@ -173,8 +181,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( ++iterator) { HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { - TF_RETURN_IF_ERROR( - AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); + TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall( + Cast(instruction), constraints)); } // For batched dot we require the default layout. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index ce24af1cf8856920ccf438b5bbd2ef28cfa8ba6f..e2b96a81d4de1337de2978a9d3c6c38c6e5fd0cd 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -47,7 +48,7 @@ class GpuLayoutAssignment : public LayoutAssignment { private: Status AddBackendConstraintsToDnnConvCustomCall( - HloInstruction* instr, LayoutConstraints* constraints); + HloCustomCallInstruction* instr, LayoutConstraints* constraints); se::StreamExecutor* stream_executor_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 286547ebae2f1a4b8d783a06d13b4dd96052b952..fbc8ddf599570b90e93eb463a1fd6c275b73711c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -119,7 +120,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -192,7 +193,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { // Enumerate all combinations of shapes. for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -265,7 +266,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { for (int constrained_param_no : {0, 4}) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 79b3f1efecdf06bfa93b17a1799f3009d517f3b5..f3c274429242d5c989146d14ea523b5910408cff 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/IR/DataLayout.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -83,7 +84,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } infeed_manager->EnqueueDestination(std::move(buffers)); @@ -96,7 +97,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( StatusOr GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size == 0) { @@ -117,38 +118,37 @@ StatusOr GpuTransferManager::TransferBufferToInfeedInternal( return std::move(buffer); } -static std::unique_ptr ShapeTreeToLiteral( +static void ShapeTreeToLiteral( ShapeTree>* shape_tree) { // This is a struct instead of a lambda for std::function-free recursion. struct Helper { - static std::unique_ptr helper( + static void helper( ShapeTree>* shape_tree, ShapeIndex* index) { const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index); if (ShapeUtil::IsArray(shape)) { - return (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + return; } CHECK(ShapeUtil::IsTuple(shape)) << ShapeUtil::HumanStringWithLayout(shape); const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); index->push_back(0); - std::vector> tuple_operands; for (int64 i = 0; i < tuple_element_count; ++i) { index->back() = i; - tuple_operands.push_back(helper(shape_tree, index)); + helper(shape_tree, index); } index->pop_back(); - return LiteralUtil::MakeTupleOwned(std::move(tuple_operands)); } }; ShapeIndex index; - return Helper::helper(shape_tree, &index); + Helper::helper(shape_tree, &index); } Status GpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* /*executor*/, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { ShapeTree> outfeed_buffers( &literal_shape); @@ -161,7 +161,10 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( if (ShapeUtil::IsTuple(shape)) { return; } - *buffer = MakeUnique(GetByteSizeRequirement(shape)); + *buffer = absl::make_unique( + GetByteSizeRequirement(shape)); + (*buffer)->set_destination( + absl::make_unique(literal, index)); }); // Give the tree of buffers to the outfeed mananger. The device will fill it @@ -169,8 +172,8 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager(); outfeed_manager->EnqueueDestination(&outfeed_buffers); - // Now turn the tree of buffers back into a literal. - *literal = std::move(*ShapeTreeToLiteral(&outfeed_buffers)); + // Now wait for the tree of buffers are written. + ShapeTreeToLiteral(&outfeed_buffers); return Status::OK(); } @@ -178,7 +181,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( } // namespace xla static std::unique_ptr CreateNVPTXTransferManager() { - return xla::MakeUnique( + return absl::make_unique( /*id=*/stream_executor::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout) .getPointerSize(0 /* default address space */)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index dceeb9e2eb01a7dd5e978d819ed1db56d828f353..fa88816bc8b0bf41f05358c0089b381305ed3182 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ #include @@ -42,7 +42,7 @@ class GpuTransferManager : public GenericTransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: // Initiates the infeed data transfers. InfeedBuffer->Done() must be @@ -61,4 +61,4 @@ class GpuTransferManager : public GenericTransferManager { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index 17226769302eef0dd01550b0bc5404e889ad78f8..b9c21e8edb2bdde03acb1fe6197a399724c9c8ab 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -33,7 +34,7 @@ namespace gpu { namespace { void InitAndStartTimer(std::stack>* timers, se::Stream* stream) { - timers->push(MakeUnique(stream->parent())); + timers->push(absl::make_unique(stream->parent())); stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get()); } @@ -115,7 +116,7 @@ HloExecutionProfiler::MakeScopedInstructionProfiler( CHECK(hlo_instructions_.insert(hlo_instruction).second) << hlo_instruction->name(); } - return MakeUnique(this, hlo_instruction); + return absl::make_unique(this, hlo_instruction); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 8c11cd05419289d82b033c936bb60884f45cb636..51627402b45f594dab3480129ba182d54d01b811 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -24,20 +25,18 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; void HloToIrBindings::EmitBasePointersForHlos( - tensorflow::gtl::ArraySlice io_hlos, - tensorflow::gtl::ArraySlice non_io_hlos) { + absl::Span io_hlos, + absl::Span non_io_hlos) { // I/O HLOs are bound to the arguments of the current IR function. I.e., // // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index eee40b0e91fc03013a6978ae3cfe42b87633eed7..c0edae530cedba45c897b07b7b9cc72eaaab397c 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace gpu { @@ -45,8 +45,8 @@ class HloToIrBindings { alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {} void EmitBasePointersForHlos( - tensorflow::gtl::ArraySlice io_hlos, - tensorflow::gtl::ArraySlice non_io_hlos); + absl::Span io_hlos, + absl::Span non_io_hlos); // Rebinds the given HLO to the LLVM IR value that represent its address. void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index c5f0cdf6cd5d3e076bffa875fbba991bf0681ee8..a4364b0deb6c97b7b580e18bf67d5f3a8fd3cc62 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" namespace xla { namespace gpu { @@ -24,7 +24,7 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(host_to_device_stream_mu_); if (host_to_device_executor_ == nullptr) { host_to_device_executor_ = executor; - host_to_device_stream_ = MakeUnique(executor); + host_to_device_stream_ = absl::make_unique(executor); host_to_device_stream_->Init(); } diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index fee6d2af3bfd4976f5845edf592e8310b55a3feb..8c3a026740851767855beae59d6a3c92f7a0d6bd 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -96,7 +96,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Infeeding to GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 0f2c83aeb2633a007559d8caac78ea2d233539ed..4d5d8e99f88149aabfd0a4aeafc7e6724d29418d 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -26,7 +27,7 @@ namespace gpu { namespace { -bool IsFusile(const HloInstruction& hlo) { +bool IsFusible(const HloInstruction& hlo) { // Don't fuse get-tuple-element on GPU: We can, but it's slower than not // fusing. We never generate kernels for unfused GTEs. Instead, if an // unfused GTE is an input to a kernel (including a fusion kernel), we @@ -41,7 +42,7 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || hlo.opcode() == HloOpcode::kGather || - hlo.opcode() == HloOpcode::kPad || + hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || hlo.opcode() == HloOpcode::kReshape || @@ -221,6 +222,13 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Do not fuse into reduce input fusions if the resulting kernel would suffer + // from poor data locality (due to unfriendly input layouts). + if (IsInputFusibleReduction(*consumer) && + !LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) { + return false; + } + // We can't fuse library calls, so if a user of such an op could become a // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for // further rationale. @@ -245,7 +253,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } - if (!IsFusile(*producer) || !IsFusile(*consumer) || + if (!IsFusible(*producer) || !IsFusible(*consumer) || !InstructionFusion::ShouldFuse(consumer, operand_index)) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 8d0522bd8fd6659e64d18c52807df8dc7fc2f3b8..96bfe0c12eb9cd6ef25804d6b34767471616f7e4 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); @@ -171,6 +172,78 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { op::Reduce(op::Broadcast(op::Constant()), op::Constant())); } +TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + constant.1 = f32[] constant(0) + ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + fused_reduce { + p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0) + mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1) + c0.1 = f32[] constant(0) + ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add + } + + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + fusion = f32[] fusion(copy), kind=kInput, calls=fused_reduce + ROOT root = (f32[]) tuple(fusion) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), op::Add(op::Copy(), op::Copy())); +} + TEST_F(InstructionFusionTest, BitcastIntoAdd) { auto module = ParseHloString(R"( HloModule test_module @@ -365,7 +438,7 @@ static StatusOr FindHloInstruction( } return NotFound( "Computation '%s' does not contain an instruction with op code '%s'.", - computation.name().c_str(), HloOpcodeString(op).c_str()); + computation.name(), HloOpcodeString(op)); } TEST_F(InstructionFusionTest, MultiOutputFusion) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index c349063c71f000435a05306101ad724505f2d197..ec3d8f9405840bb7be97ba5cd5725a4ac68a15a8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -128,6 +129,8 @@ const char* const kCudnnConvBackwardInputCallTarget = "__cudnn$convBackwardInput"; const char* const kCudnnConvBackwardFilterCallTarget = "__cudnn$convBackwardFilter"; +const char* const kCudnnConvBiasActivationForwardCallTarget = + "__cudnn$convBiasActivationForward"; bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { if (hlo.opcode() != HloOpcode::kCustomCall) { @@ -136,7 +139,8 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { const auto& target = hlo.custom_call_target(); return target == kCudnnConvForwardCallTarget || target == kCudnnConvBackwardInputCallTarget || - target == kCudnnConvBackwardFilterCallTarget; + target == kCudnnConvBackwardFilterCallTarget || + target == kCudnnConvBiasActivationForwardCallTarget; } bool ImplementedAsLibraryCall(const HloInstruction& hlo) { @@ -144,51 +148,6 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) { IsCustomCallToDnnConvolution(hlo); } -static HloInstruction* CreateCudnnConv( - const char* call_target, const Shape& shape, HloInstruction* lhs, - HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums) { - HloComputation* computation = lhs->parent(); - - // This call returns a tuple of (conv_result, scratch_memory), where - // conv_result is the actual result of the convolution, and scratch_memory is - // temporary memory used by cudnn. - // - // At the moment, we don't know how much scratch memory this conv is going to - // use, so we put u8[0] in this place. Later on another pass will choose - // which conv algorithm to use, and at that point we'll modify the shape of - // this second tuple element. - Shape call_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - - HloInstruction* custom_call = computation->AddInstruction( - HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); - custom_call->set_window(window); - custom_call->set_convolution_dimension_numbers(dnums); - return custom_call; -} - -HloInstruction* CreateCudnnConvForward( - const Shape& shape, HloInstruction* input, HloInstruction* kernel, - const Window& window, const ConvolutionDimensionNumbers& dnums) { - return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, - window, dnums); -} - -HloInstruction* CreateCudnnConvBackwardInput( - const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums) { - return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, - reverse_filter, window, dnums); -} - -HloInstruction* CreateCudnnConvBackwardFilter( - const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums) { - return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, - output, window, dnums); -} - bool IsReductionToVector(const HloInstruction& reduce) { if (HloOpcode::kReduce != reduce.opcode()) { return false; @@ -215,8 +174,8 @@ bool IsReductionToVector(const HloInstruction& reduce) { // This emits a device-side call to // "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see // http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, - tensorflow::gtl::ArraySlice arguments, +llvm::Value* EmitPrintf(absl::string_view fmt, + absl::Span arguments, llvm::IRBuilder<>* builder) { std::vector argument_types; for (auto argument : arguments) { @@ -279,5 +238,36 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } +StatusOr GetCudnnConvKind( + const HloCustomCallInstruction* instr) { + absl::string_view target = instr->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + return CudnnConvKind::kForward; + } + if (target == kCudnnConvBackwardInputCallTarget) { + return CudnnConvKind::kBackwardInput; + } + if (target == kCudnnConvBackwardFilterCallTarget) { + return CudnnConvKind::kBackwardFilter; + } + if (target == kCudnnConvBiasActivationForwardCallTarget) { + return CudnnConvKind::kForwardActivation; + } + return InternalError("Unexpected call target: %s", target); +} + +string CudnnConvKindToString(CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kForward: + return "forward"; + case CudnnConvKind::kBackwardFilter: + return "backward_filter"; + case CudnnConvKind::kBackwardInput: + return "backward_input"; + case CudnnConvKind::kForwardActivation: + return "forward with activation"; + } +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 5d23a3d01842c7b4ff405171cd49c96a19f7e5b0..a64a616ab1329422d0197f4a7f99ec557a95f8ed 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they // don't belong in "ir_emission_utils". @@ -28,6 +29,33 @@ limitations under the License. namespace xla { namespace gpu { +// Different types of convolutions supported by cudnn. +// +// A way to think about these is that a convolution is defined by three arrays +// -- the "input", the "filter", and the "output" -- and given any two of these, +// we can compute the third. For example, a backward-input convolution takes as +// input a filter and an "output" and produces an "input" such that if one were +// to do a forward convolution of "input" using filter, the result would be +// something with the same shape as "output". +// +// This way of thinking is not correct if you look at the values produced. For +// example, a backward-input convolution is not actually the mathematical +// inverse of a forward convolution. But it's right as far as the shapes and +// "connectivity" (i.e. which elements of the input affect which elements of +// the output) are concerned. +enum class CudnnConvKind { + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter + kForwardActivation, // activation(conv(input, filter) + broadcast(bias) + + // (optionally) side_input) => output +}; + +StatusOr GetCudnnConvKind(const HloCustomCallInstruction* instr); + +// Converts a CudnnConvKind value to a string. +string CudnnConvKindToString(CudnnConvKind kind); + constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. @@ -93,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); extern const char* const kCudnnConvForwardCallTarget; extern const char* const kCudnnConvBackwardInputCallTarget; extern const char* const kCudnnConvBackwardFilterCallTarget; +extern const char* const kCudnnConvBiasActivationForwardCallTarget; // Returns true if `hlo` will be implemented as a call to a cuDNN convolution // routine. @@ -102,23 +131,6 @@ extern const char* const kCudnnConvBackwardFilterCallTarget; // kConvolution opcode. bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); -// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv. -// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If -// you want just the conv result, you'll need to get-tuple-element the value -// returned by this function. -// -// The created cudnn call will use the default cudnn algorithm and no scratch -// space. -HloInstruction* CreateCudnnConvForward( - const Shape& shape, HloInstruction* input, HloInstruction* kernel, - const Window& window, const ConvolutionDimensionNumbers& dnums); -HloInstruction* CreateCudnnConvBackwardInput( - const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums); -HloInstruction* CreateCudnnConvBackwardFilter( - const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums); - // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); @@ -126,8 +138,8 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo); bool IsReductionToVector(const HloInstruction& reduce); // Emits call to "vprintf" with given format and arguments. -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, - tensorflow::gtl::ArraySlice arguments, +llvm::Value* EmitPrintf(absl::string_view fmt, + absl::Span arguments, llvm::IRBuilder<>* builder); // Emits code to shuffle data between threads of a warp. This has the same diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 541cacf6970453033c09a153a2dd320d4ebf3d4a..b7c37bcf3ca910f10d18339dfe7f1d29f2a55c9e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -64,7 +65,7 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, hlo_module_config_(hlo_module_config) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math())); + .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { @@ -140,7 +141,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { Status IrEmitter::EmitCallToNestedComputation( const HloComputation& nested_computation, - tensorflow::gtl::ArraySlice operands, llvm::Value* output) { + absl::Span operands, llvm::Value* output) { TF_RET_CHECK(nested_computation.num_parameters() > 0); llvm::Function*& emitted_function = computation_to_ir_function_[&nested_computation]; @@ -155,7 +156,7 @@ Status IrEmitter::EmitCallToNestedComputation( std::vector arguments(operands.begin(), operands.end()); arguments.push_back(output); arguments.push_back(bindings_.GetTempBufferBase()); - b_.CreateCall(emitted_function, arguments); + Call(emitted_function, arguments); return Status::OK(); } @@ -177,7 +178,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( computation.root_instruction()->shape().element_type(); bool is_atomic_integral = element_type == S32 || element_type == U32 || element_type == S64 || element_type == U64; - llvm::Value* source = b_.CreateLoad(source_address, "source"); + llvm::Value* source = Load(source_address, "source"); if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { @@ -189,8 +190,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( } if (is_atomic_integral) { // integral + integral - b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } } @@ -201,8 +202,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Max : llvm::AtomicRMWInst::UMax; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -211,8 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Min : llvm::AtomicRMWInst::UMin; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -291,10 +292,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // cas_old_output_address and cas_new_output_address point to the scratch // memory where we store the old and new values for the repeated atomicCAS // operations. - llvm::Value* cas_old_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); - llvm::Value* cas_new_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); + llvm::Value* cas_old_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); // Emit preparation code to the preheader. llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock(); @@ -308,29 +309,26 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, CHECK_EQ((element_size % sizeof(char)), 0); llvm::Type* address_int_type = module_->getDataLayout().getIntPtrType(output_address_type); - atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type); + atomic_memory_address = PtrToInt(output_address, address_int_type); llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); - llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask); + llvm::Value* offset = And(atomic_memory_address, mask); mask = llvm::ConstantInt::get(address_int_type, -4); - atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = And(atomic_memory_address, mask); atomic_memory_address = - b_.CreateIntToPtr(atomic_memory_address, atomic_address_type); - binop_output_address = b_.CreateAdd( - b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset); + IntToPtr(atomic_memory_address, atomic_address_type); binop_output_address = - b_.CreateIntToPtr(binop_output_address, element_address_type); + Add(PtrToInt(cas_new_output_address, address_int_type), offset); + binop_output_address = IntToPtr(binop_output_address, element_address_type); } else { - atomic_memory_address = - b_.CreateBitCast(output_address, atomic_address_type); + atomic_memory_address = BitCast(output_address, atomic_address_type); binop_output_address = - b_.CreateBitCast(cas_new_output_address, element_address_type); + BitCast(cas_new_output_address, element_address_type); } // Use the value from the memory that atomicCAS operates on to initialize // cas_old_output. - llvm::Value* cas_old_output = - b_.CreateLoad(atomic_memory_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_old_output_address); + llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output"); + Store(cas_old_output, cas_old_output_address); llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( b_.GetInsertPoint(), "atomic_op_loop_exit"); @@ -343,32 +341,29 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // Emit the body of the loop that repeatedly invokes atomicCAS. // // Use cas_old_output to initialize cas_new_output. - cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_new_output_address); + cas_old_output = Load(cas_old_output_address, "cas_old_output"); + Store(cas_old_output, cas_new_output_address); // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( computation, {binop_output_address, source_address}, binop_output_address)); - llvm::Value* cas_new_output = - b_.CreateLoad(cas_new_output_address, "cas_new_output"); + llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output"); // Emit code to perform the atomicCAS operation // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, // cas_new_output); - llvm::Value* ret_value = b_.CreateAtomicCmpXchg( - atomic_memory_address, cas_old_output, cas_new_output, - llvm::AtomicOrdering::SequentiallyConsistent, - llvm::AtomicOrdering::SequentiallyConsistent); + llvm::Value* ret_value = + AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output, + llvm::AtomicOrdering::SequentiallyConsistent, + llvm::AtomicOrdering::SequentiallyConsistent); // Extract the memory value returned from atomicCAS and store it as // cas_old_output. - b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"), - cas_old_output_address); + Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address); // Extract the success bit returned from atomicCAS and generate a // conditional branch on the success bit. - b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, - loop_body_bb); + CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. @@ -383,8 +378,8 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( // TODO(b/30258929): We only accept binary computations so far. return Unimplemented( "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); + "computation %s has %d.", + computation.name(), computation.num_parameters()); } if (MaybeEmitDirectAtomicOperation(computation, output_address, @@ -471,10 +466,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_value, rhs_value, &b_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = b_.CreateInsertValue(result, value.first, {0}); - result = b_.CreateInsertValue(result, value.second, {1}); + result = InsertValue(result, value.first, {0}); + result = InsertValue(result, value.second, {1}); } else { - result = b_.CreateFMul(lhs_value, rhs_value); + result = FMul(lhs_value, rhs_value); } target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); @@ -518,7 +513,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // We don't have to iterate over the batch dimensions in both arrays, simplify // the loop nest of the rhs. for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { - DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i)); + DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i)); rhs_index[i] = lhs_index[i]; } @@ -558,21 +553,21 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); - llvm::Value* accum = b_.CreateLoad(accum_address); + llvm::Value* accum = Load(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_element, rhs_element, &b_); llvm::Value* accum_real = Real(accum, &b_); - llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first); - updated_accum = b_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* real_sum = FAdd(accum_real, value.first); + updated_accum = InsertValue(accum, real_sum, {0}); llvm::Value* accum_imag = Imag(accum, &b_); - llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second); - updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1}); + llvm::Value* imag_sum = FAdd(accum_imag, value.second); + updated_accum = InsertValue(updated_accum, imag_sum, {1}); } else { - llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element); - updated_accum = b_.CreateFAdd(accum, product); + llvm::Value* product = FMul(lhs_element, rhs_element); + updated_accum = FAdd(accum, product); } - b_.CreateStore(updated_accum, accum_address); + Store(updated_accum, accum_address); // After the reduction loop exits, store the accumulator into the target // address. The index into the target address is the concatenation of the rhs @@ -594,7 +589,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); target_array.EmitWriteArrayElement( target_index, - b_.CreateLoad(accum_address), // The value written to the target array. + Load(accum_address), // The value written to the target array. &b_); // Set the IR builder insert point to the exit basic block of the outer most @@ -632,19 +627,22 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { } Status IrEmitter::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Support variadic reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Variadic reduce is not supported on GPU"); + } auto arg = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); return EmitTargetElementLoop( *reduce, [=](const llvm_ir::IrArray::Index& index) -> StatusOr { // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = - b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + Alloca(llvm_ir::PrimitiveTypeToIrType( reduce->shape().element_type(), module_)); - b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)), - accumulator_addr); + Store(Load(GetBasePointer(*init_value)), accumulator_addr); // The enclosing loops go over all the target elements. Now we have to // compute the actual target element. For this, we build a new loop nest @@ -681,7 +679,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { *function, {accumulator_addr, input_address}, accumulator_addr)); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); }); } @@ -748,14 +746,9 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -Status IrEmitter::HandleIota(HloInstruction*) { - // TODO(b/64798317): implement iota on GPU. - return Unimplemented("Iota is not implemented on GPU."); -} - StatusOr IrEmitter::ComputeNestedElement( const HloComputation& computation, - tensorflow::gtl::ArraySlice parameter_elements) { + absl::Span parameter_elements) { llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType( computation.root_instruction()->shape().element_type(), module_), @@ -764,11 +757,26 @@ StatusOr IrEmitter::ComputeNestedElement( for (llvm::Value* parameter_element : parameter_elements) { parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( parameter_element->getType(), "parameter_buffer", &b_)); - b_.CreateStore(parameter_element, parameter_buffers.back()); + Store(parameter_element, parameter_buffers.back()); } TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, return_buffer)); - return b_.CreateLoad(return_buffer); + return Load(return_buffer); +} + +std::vector IrEmitter::ConstructIrArrayForOutputs( + const HloInstruction& hlo) { + std::vector output_arrays; + if (ShapeUtil::IsTuple(hlo.shape())) { + int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_arrays.reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + } else { + output_arrays.push_back(GetIrArray(hlo, hlo)); + } + return output_arrays; } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 561c6838798aa92ce2c96b3c45d5ba42fe6edef3..880520148005838cc25a5be9e26c8bc9028a70ce 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -35,13 +37,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -64,7 +65,8 @@ namespace gpu { // IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is // not a subclass of gpu::IrEmitter, and in fact is better understood as an IR // generator generator. See comments on that class. -class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin { public: IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; @@ -95,10 +97,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; - Status HandleIota(HloInstruction* iota) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } + llvm::IRBuilder<>* builder() { return &b_; } + protected: // Constructs an IrEmitter with the given IrEmitter context. // ir_emitter_context is owned by the caller and should outlive the IrEmitter @@ -121,6 +124,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } + + // Generates the IrArray for each output of an hlo instruction and returns + // a vector containing such IrArrays. + std::vector ConstructIrArrayForOutputs( + const HloInstruction& hlo); + // A convenient helper for calling BufferAssignment::GetUniqueSlice. BufferAllocation::Slice GetAllocationSlice( const HloInstruction& hlo, const ShapeIndex& index = {}) const { @@ -140,9 +149,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Emits a call in IR to the given nested computation with the given operands // and output. If no IR function has been previously emitted for the // computation, also emits such a function. - Status EmitCallToNestedComputation( - const HloComputation& nested_computation, - tensorflow::gtl::ArraySlice operands, llvm::Value* output); + Status EmitCallToNestedComputation(const HloComputation& nested_computation, + absl::Span operands, + llvm::Value* output); // Emits an atomic operation that implements `nested_computation` in the // sequentially consistent memory model. `output_address` and `source_address` @@ -196,7 +205,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { StatusOr ComputeNestedElement( const HloComputation& computation, - tensorflow::gtl::ArraySlice parameter_elements); + absl::Span parameter_elements); // Emits an atomic operation that implements `nested_computation` in the // sequentially consistent memory model. `output_address` and `source_address` diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5c827e5f9cf3e1c04af444dae338a2ec411ce372..66c65f69758e5a2f4420935279835eaf086fea45 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -119,21 +119,11 @@ Status IrEmitterNested::EmitTargetElementLoop( // For MOF we give the loop emitter an array for every output it should // generate. if (hlo.IsMultiOutputFusion()) { - const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape()); - std::vector target_arrays; - target_arrays.reserve(num_elems); - for (int64 i = 0; i != num_elems; ++i) { - target_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector target_arrays = + ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); - - std::vector tuple_operand_ptrs; - tuple_operand_ptrs.reserve(num_elems); - for (const llvm_ir::IrArray& array : target_arrays) { - tuple_operand_ptrs.push_back(array.GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_); return Status::OK(); } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index d5ecae88ed519c7123b6231da981172a4a4de304..c792dd2ddb0faeba076548ba104aa291e0814140 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -21,6 +21,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -29,7 +35,6 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" @@ -56,7 +61,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -68,6 +73,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -76,8 +82,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -85,13 +89,12 @@ namespace gpu { namespace { +using absl::InlinedVector; +using absl::nullopt; +using absl::optional; +using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; -using tensorflow::gtl::ArraySlice; -using tensorflow::gtl::InlinedVector; -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; -using tensorflow::strings::StrCat; // If a dimensions is smaller than this, untiled transposition may be more // efficient. @@ -173,7 +176,7 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice args) { + absl::Span args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( @@ -314,13 +317,13 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, }; // Check the size of input tensors - if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { + if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { return i64_ty; } // Check the size of the internal result tensors if (unnested_hlo->opcode() == HloOpcode::kFusion) { - if (!c_all_of( + if (!absl::c_all_of( unnested_hlo->fused_instructions_computation()->instructions(), hlo_shape_in_range)) { return i64_ty; @@ -383,7 +386,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { int64 feature_index_value = feature_index->literal().Get({}); thunk_sequence_->emplace_back( - MakeUnique( + absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -413,7 +416,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); thunk_sequence_->emplace_back( - MakeUnique( + absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -443,85 +446,37 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_grad_offset = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - thunk_sequence_->emplace_back(MakeUnique( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - /*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*mean=*/GetAllocationSlice(*custom_call->operand(2)), - /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), - /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output_grad_data=*/output_grad_data, - /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + thunk_sequence_->emplace_back( + absl::make_unique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); return Status::OK(); } if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - const auto& lhs_shape = custom_call->operand(0)->shape(); - const auto& rhs_shape = custom_call->operand(1)->shape(); - const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); - auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); - auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + std::vector operand_slices; + operand_slices.reserve(custom_call->operand_count()); + for (const auto* operand : custom_call->operands()) { + operand_slices.push_back(GetAllocationSlice(*operand)); + } auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - custom_call->backend_config()); - const auto& target = custom_call->custom_call_target(); - std::unique_ptr thunk; - if (target == kCudnnConvForwardCallTarget) { - thunk = MakeUnique( - CudnnConvKind::kForward, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/conv_result_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); - } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = MakeUnique( - CudnnConvKind::kBackwardInput, - /*input_buffer=*/conv_result_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/lhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/lhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = MakeUnique( - CudnnConvKind::kBackwardFilter, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/conv_result_slice, - /*output_buffer=*/rhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, - /*output_shape=*/rhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << custom_call->custom_call_target(); - } - - thunk_sequence_->emplace_back(std::move(thunk)); + thunk_sequence_->emplace_back(absl::make_unique( + Cast(custom_call), std::move(operand_slices), + conv_result_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } @@ -545,12 +500,17 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { switch (root->opcode()) { case HloOpcode::kTuple: case HloOpcode::kReduce: { + if (root->opcode() == HloOpcode::kReduce && + ShapeUtil::IsTuple(root->shape())) { + // TODO(b/112040122): Support variadic reduce. + return Unimplemented("Variadic reduce is not supported on GPU"); + } VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); std::vector> thunks; - ArraySlice output_instructions = + absl::Span output_instructions = root->opcode() == HloOpcode::kTuple ? root->operands() - : ArraySlice(&root, 1); + : absl::Span(&root, 1); // For multi-output fusion emit an initializer for each tuple element. // Otherwise it's sufficient to just initialize the single output. @@ -571,7 +531,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunks.push_back( BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), fusion)); + absl::make_unique(std::move(thunks), fusion)); std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); @@ -709,8 +669,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { Status IrEmitterUnnested::EmitExtraOutputsForReduce( const HloInstruction* reduce, const IrArray::Index& index, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { const HloInstruction* output = reduce->parent()->FusionInstruction(); @@ -720,19 +679,18 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); - b_.CreateStore(extra_output_ir_value, extra_output_address); + Store(extra_output_ir_value, extra_output_address); } return Status::OK(); } Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // Number of elements processed by a single thread. constexpr int64 kTileSize = 16; @@ -793,8 +751,7 @@ Status IrEmitterUnnested::EmitReductionToScalar( // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), // // // // and threads_per_block is a multiple of warpSize. - // reduce_kernel<<>>(); - // + // reduce_kernel // auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = @@ -802,17 +759,17 @@ Status IrEmitterUnnested::EmitReductionToScalar( std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { @@ -824,15 +781,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)), - tile_element_loop->GetIndVarValue()); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), + tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. @@ -841,11 +797,11 @@ Status IrEmitterUnnested::EmitReductionToScalar( IrArray::Index input_index( /*linear=*/x, input_shape, &b_); - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -856,14 +812,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileSize), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileSize), + NSWMul(x_in_tiles, index_typed_constant(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. llvm::Value* tile_in_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)), - b_.getInt1(all_threads_in_bounds)); + Or(ICmpULE(x_end, index_typed_constant(num_elems)), + b_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); @@ -884,20 +840,18 @@ Status IrEmitterUnnested::EmitReductionToScalar( for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -912,10 +866,9 @@ Status IrEmitterUnnested::EmitReductionToScalar( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm::Value* lane_id = - b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); + URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { @@ -947,12 +900,11 @@ Status IrEmitterUnnested::EmitReductionToScalar( Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // Divide the input matrix into tiles of size KxL. For example, when the // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like @@ -1035,12 +987,12 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + - llvm::Twine(i * kTileWidth + x_offset)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + + llvm::Twine(i * kTileWidth + x_offset)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1051,8 +1003,8 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* y_in_tiles = tile_index[0]; llvm::Value* x_in_tiles = tile_index[1]; - y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty); - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); auto emit_tile_element_loop = [=](bool tile_in_y_bounds, bool tile_in_x_bounds) -> Status { @@ -1064,34 +1016,32 @@ Status IrEmitterUnnested::EmitColumnReduction( // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* y = b_.CreateNSWAdd( - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)), - tile_element_loop->GetIndVarValue()); + llvm::Value* y = + NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), + tile_element_loop->GetIndVarValue()); // Unless we know that y is in bounds, we have to emit a check before // reading from the input. if (!tile_in_y_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds", - &b_); + ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); // Unless we know that x is in bounds, we have to emit a check before // reading from the input. if (!tile_in_x_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); // {y,x} is an index to input_matrix_shape [height,width]. We need to // convert that to an index to input_shape (the shape of the operand of // "reduce"). This conversion is composed of a transposition from @@ -1118,7 +1068,7 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i * kTileWidth + x_offset], @@ -1133,20 +1083,20 @@ Status IrEmitterUnnested::EmitColumnReduction( // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location // that's immediately beyond the tile. - llvm::Value* y_end = b_.CreateNSWAdd( - index_typed_constant(kTileHeight), - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight))); + llvm::Value* y_end = + NSWAdd(index_typed_constant(kTileHeight), + NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location // that's immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileWidth), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileWidth), + NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); llvm::Value* tile_in_y_bounds = - b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)), - b_.getInt1(height % kTileHeight == 0)); + Or(ICmpULE(y_end, index_typed_constant(height)), + b_.getInt1(height % kTileHeight == 0)); llvm::Value* tile_in_x_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)), - b_.getInt1(width % kTileWidth == 0)); + Or(ICmpULE(x_end, index_typed_constant(width)), + b_.getInt1(width % kTileWidth == 0)); // The tile is in y bounds if "height" is a multiple of kTileHeight or // y_end <= height. llvm_ir::LlvmIfData if_tile_in_y_bounds_data = @@ -1180,9 +1130,9 @@ Status IrEmitterUnnested::EmitColumnReduction( reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( @@ -1238,12 +1188,11 @@ static std::pair ComputeTilingSchemeForReduction( Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // A naive algorithm is: // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. @@ -1371,11 +1320,11 @@ Status IrEmitterUnnested::EmitRowReduction( std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1384,22 +1333,20 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty); + x_tile = ZExtOrTrunc(x_tile, index_ty); llvm::Value* warp_id = - b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); + UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); llvm::Value* lane_id = - b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id"); + URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); // The x-location of the last element in this z-x-tile. // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = b_.CreateNSWAdd( + llvm::Value* last_x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - index_typed_constant(x_tile_size - 1), - b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(index_typed_constant(x_tile_size - 1), + NSWMul(warp_id, index_typed_constant(x_tile_size))))); KernelSupportLibrary ksl( &b_, @@ -1411,9 +1358,8 @@ Status IrEmitterUnnested::EmitRowReduction( auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, int64 x_tile_loop_bound) -> Status { auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = b_.CreateNSWAdd( - z_indvar, - b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile)); + llvm::Value* z = + NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); TF_RETURN_IF_ERROR(ksl.For( "x_tile", /*start=*/index_typed_constant(0), @@ -1421,22 +1367,20 @@ Status IrEmitterUnnested::EmitRowReduction( /*step=*/1, [&](llvm::Value* x_indvar) -> Status { // x = lane_id + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = b_.CreateNSWAdd( + llvm::Value* x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - x_indvar, b_.CreateNSWMul( - warp_id, llvm::ConstantInt::get( - index_ty, x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(x_indvar, + NSWMul(warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); // Unless we know the x-tile is entirely in bounds, we have to // emit a x-in-bounds check before reading from the input. if (!x_tile_in_bounds) { llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), - "x_in_bounds", &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", + &b_); // Points b_ to the then-block. llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, &b_); @@ -1444,7 +1388,7 @@ Status IrEmitterUnnested::EmitRowReduction( // Emit code that reads the input element and accumulates it // to the partial reduction result. - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); { // {z,y,x} is an index to input_3d_tensor_shape // [depth,height,width]. We need to convert that to an index @@ -1475,7 +1419,7 @@ Status IrEmitterUnnested::EmitRowReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1495,8 +1439,8 @@ Status IrEmitterUnnested::EmitRowReduction( }; llvm::Value* tile_in_bounds = - b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), - b_.CreateICmpULT(last_x, index_typed_constant(width))); + Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), + ICmpULT(last_x, index_typed_constant(width))); TF_RETURN_IF_ERROR( ksl.If(tile_in_bounds, @@ -1524,20 +1468,18 @@ Status IrEmitterUnnested::EmitRowReduction( for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1552,8 +1494,7 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = @@ -1599,13 +1540,12 @@ Status IrEmitterUnnested::EmitRowReduction( // elementwise. Status IrEmitterUnnested::EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for @@ -1694,9 +1634,13 @@ Status IrEmitterUnnested::EmitReductionToVector( } Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Support multi-output reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Multi-output reduce is not supported on GPU"); + } auto input = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions_to_reduce(reduce->dimensions()); + absl::Span dimensions_to_reduce(reduce->dimensions()); HloComputation* reducer = reduce->to_apply(); // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. The specialized version requires an initializer thunk that @@ -1709,7 +1653,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { thunks.push_back( BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), reduce)); + absl::make_unique(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), {[&](const IrArray::Index& index) { @@ -1729,7 +1673,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { bool all_tuple_elements_have_buffer = - c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { + absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment() .GetUniqueTopLevelSlice(tuple_element) .ok(); @@ -1751,7 +1695,7 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } - thunk_sequence_->emplace_back(MakeUnique( + thunk_sequence_->emplace_back(absl::make_unique( tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } @@ -1783,8 +1727,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( thunks.push_back(std::move(initializer_thunk)); thunks.push_back(BuildKernelThunk(select_and_scatter, /*implements_whole_instruction=*/false)); - thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), select_and_scatter)); + thunk_sequence_->emplace_back(absl::make_unique( + std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -1833,7 +1777,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, @@ -1854,15 +1798,15 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index operand_index(index_type, source_index.size()); llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( + llvm::Value* strided_index = NSWMul( source_index[i], index_typed_constant(window.dimensions(i).stride())); - operand_index[i] = b_.CreateNSWSub( - b_.CreateNSWAdd(strided_index, window_index[i]), - index_typed_constant(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( + operand_index[i] = + NSWSub(NSWAdd(strided_index, window_index[i]), + index_typed_constant(window.dimensions(i).padding_low())); + llvm::Value* index_condition = ICmpULT( operand_index[i], index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -1872,7 +1816,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. @@ -1880,16 +1824,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter( const auto save_operand_index = [&](const IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to // potentially update the selected value and index with the currently @@ -1905,11 +1849,11 @@ Status IrEmitterUnnested::HandleSelectAndScatter( TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), {selected_value_address, operand_address}, select_return_buffer)); - llvm::Value* result = b_.CreateLoad(select_return_buffer); + llvm::Value* result = Load(select_return_buffer); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( PRED, ir_emitter_context_->llvm_module()), @@ -1918,7 +1862,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -1930,8 +1874,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) @@ -1963,19 +1907,13 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; // Build ForThunk for conformant while loops, otherwise build WhileThunk. - auto result = CanTransformWhileToFor(xla_while); - if (result.ok()) { - auto tuple = result.ConsumeValueOrDie(); - // loop_trip_count = (limit - start + increment - 1) / increment - const int64 loop_trip_count = - (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) / - std::get<2>(tuple); - thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count)); + // TODO(b/112163966): Move trip count computation earlier in the pipeline. + if (auto loop_trip_count = ComputeWhileLoopTripCount(xla_while)) { + thunk_sequence_->emplace_back(BuildForThunk(xla_while, *loop_trip_count)); VLOG(3) << "Built ForThunk for while: " << xla_while->name(); } else { thunk_sequence_->emplace_back(BuildWhileThunk(xla_while)); - VLOG(3) << "Built WhileThunk for while: " << xla_while->name() - << " while-to-for transform status: " << result.status(); + VLOG(3) << "Built WhileThunk for while: " << xla_while->name(); } return Status::OK(); } @@ -2015,7 +1953,7 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { thunks.push_back(std::move(rng_thunk)); thunks.push_back(std::move(increment_seed_thunk)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), rng)); + absl::make_unique(std::move(thunks), rng)); return Status::OK(); } @@ -2040,7 +1978,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { auto values_destination = GetAllocationSlice(*sort, values_shape_index); if (keys_destination != GetAllocationSlice(*keys)) { - thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*keys), /*destination_buffer=*/keys_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); @@ -2048,7 +1986,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { if (values != nullptr && values_destination != GetAllocationSlice(*values)) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. - thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*values), /*destination_buffer=*/values_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); @@ -2092,15 +2030,15 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), - values != nullptr ? tensorflow::gtl::make_optional( + values != nullptr ? absl::make_optional( GetIrArray(*sort, *sort, values_shape_index)) - : tensorflow::gtl::nullopt, + : absl::nullopt, IrName(sort), xor_mask, &b_, &launch_dimensions)); } } thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), sort)); + absl::make_unique(std::move(thunks), sort)); return Status::OK(); } @@ -2127,7 +2065,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (crs->operand_count() == 1) { CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - thunk_sequence_->push_back(MakeUnique( + thunk_sequence_->push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); @@ -2142,17 +2080,17 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() .GetUniqueSlice(crs, {i}) .ValueOrDie()); - thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); } // Output a tuple of the buffers above. - thunks.push_back(MakeUnique(tuple_element_buffers, - GetAllocationSlice(*crs), nullptr)); + thunks.push_back(absl::make_unique( + tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); thunk_sequence_->push_back( - MakeUnique(std::move(thunks), crs)); + absl::make_unique(std::move(thunks), crs)); return Status::OK(); } @@ -2302,7 +2240,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( for (const auto& kv : hlo_slices) { buffers_needed.insert(kv.second.first.allocation()); } - tensorflow::gtl::optional temp_buffer; + absl::optional temp_buffer; for (const BufferAllocation& alloc : buffer_assn.Allocations()) { if (alloc.IsPreallocatedTempBuffer()) { if (!temp_buffer.has_value()) { @@ -2319,10 +2257,10 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // We'll pass a pointer to each of the elements of `buffers` to our kernel, in // this order. std::vector non_constant_buffers; - c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), - [](const BufferAllocation* allocation) { - return !allocation->is_constant(); - }); + absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), + [](const BufferAllocation* allocation) { + return !allocation->is_constant(); + }); std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), [](const BufferAllocation* a, const BufferAllocation* b) { @@ -2361,8 +2299,8 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( *slice.allocation()))); CHECK_NE(loc, nullptr); } else { - loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + loc = InBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); } // If gte_index is nonempty, we have to dereference `loc` to get to the @@ -2370,8 +2308,8 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::Type* int8_double_pointer = llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0); for (int64 idx : gte_index) { - loc = b_.CreateBitCast(loc, int8_double_pointer); - loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)})); + loc = BitCast(loc, int8_double_pointer); + loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)})); } bindings_.BindHloToIrValue(*instr, loc, index); @@ -2386,7 +2324,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique( + return absl::make_unique( non_constant_buffers, llvm_ir::AsString(kernel->getName()), implements_whole_instruction ? inst : nullptr, unroll_factor); } @@ -2395,7 +2333,7 @@ std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); - return MakeUnique( + return absl::make_unique( /*source_address=*/operand->literal().untyped_data(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2407,7 +2345,7 @@ std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( std::unique_ptr IrEmitterUnnested::BuildDeviceToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique( + return absl::make_unique( /*source_address=*/GetAllocationSlice(*operand), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2427,7 +2365,7 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( .GetUniqueSlice(inst, index) .ConsumeValueOrDie(); }); - return MakeUnique(slices, inst); + return absl::make_unique(slices, inst); } std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( @@ -2444,7 +2382,7 @@ std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( *slice = status_or_slice.ConsumeValueOrDie(); } }); - return MakeUnique(std::move(slices), inst); + return absl::make_unique(std::move(slices), inst); } namespace { @@ -2467,7 +2405,7 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (inst->opcode() == HloOpcode::kDot) { const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2509,7 +2447,7 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* rhs = inst->operand(rhs_parameter->parameter_number()); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2526,23 +2464,24 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( std::unique_ptr IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique(inst->fft_type(), inst->fft_length(), - /*input_buffer=*/GetAllocationSlice(*operand), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/operand->shape(), - /*output_shape=*/inst->shape(), inst); + return absl::make_unique( + inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); } StatusOr> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index) { + HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value_operand = [&] { + HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; + HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: - return inst->operand(2); + return inst->mutable_operand(2); case HloOpcode::kReduce: - return inst->operand(1); + return inst->mutable_operand(1); case HloOpcode::kTuple: CHECK(hlo->IsMultiOutputFusion()) << ": " << hlo->ToString() << " is not a multi-output fusion."; @@ -2550,7 +2489,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( << ": Found '" << inst->operand(index.back())->opcode() << "' in " << inst->ToString() << " but expected 'reduce'."; // For multi-output fusion look through the tuple. - return inst->operand(index.back())->operand(1); + return inst->mutable_operand(index.back())->mutable_operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; @@ -2577,11 +2516,11 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // Are all the bytes of this scalar equal to 0? If so, we can create a // MemzeroThunk. - ArraySlice literal_bytes( + absl::Span literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); - if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return { - MakeUnique(GetAllocationSlice(*hlo, index), nullptr)}; + if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { + return {absl::make_unique(GetAllocationSlice(*hlo, index), + nullptr)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2598,7 +2537,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique( + return {absl::make_unique( pattern32, GetAllocationSlice(*hlo, index), nullptr)}; } @@ -2609,7 +2548,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique( + return {absl::make_unique( word, GetAllocationSlice(*hlo, index), nullptr)}; } } @@ -2622,28 +2561,35 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); - // If the init_value was fused into this reduce we have to generate it first. - if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { - CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - const Literal& literal = init_value_operand->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); + if (fused) { + // If init_value was fused into this reduce we have to generate it first. + std::vector parameter_arrays; + for (HloInstruction* operand : hlo->operands()) { + parameter_arrays.push_back(GetIrArray(*operand, *hlo)); + } + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - global_for_const->setAlignment(kConstantBufferAlignBytes); - bindings_.BindHloToIrValue(*init_value_operand, global_for_const); + FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter)); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand), + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); + } else { + // In the unfused case the element is already there, just read from it. + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &b_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); } - TF_RETURN_IF_ERROR(ParallelLoopEmitter( - [=](const IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &b_); - }, - GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_) - .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) @@ -2667,8 +2613,7 @@ Status CheckHloBuffersShareAllocation( if (slice_a != slice_b) { return InternalError( "instruction %s %s does not share allocation with instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString()); } return Status::OK(); } @@ -2761,7 +2706,7 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition.ConsumeThunkSequence(), ir_emitter_body.ConsumeThunkSequence(), hlo); @@ -2779,8 +2724,8 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique(loop_limit, - ir_emitter_body.ConsumeThunkSequence(), hlo); + return absl::make_unique( + loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); } std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( @@ -2800,7 +2745,7 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( ir_emitter_context_); TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo->operand(1)), GetAllocationSlice(*hlo->operand(2)), @@ -2833,10 +2778,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( } // For multioutput fusion, we need to emit each operand and the root. - std::vector output_arrays; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { - output_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector output_arrays = ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -2844,12 +2786,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( GetIndexTypeForKernel( &hlo, launch_dimensions.launch_bound(), &b_))); - std::vector tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + return Status::OK(); } @@ -2861,34 +2800,19 @@ Status IrEmitterUnnested::EmitTargetElementLoop( static_cast(LastThunk())); } -int IrEmitterUnnested::ConstructIrArrayForOutputs( - const HloInstruction& hlo, std::vector* output_arrays) { - int64 num_outputs = 1; - if (hlo.IsMultiOutputFusion()) { - num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); - output_arrays->reserve(num_outputs); - for (int64 i = 0; i < num_outputs; ++i) { - output_arrays->push_back(GetIrArray(hlo, hlo, {i})); - } - } else { - output_arrays->push_back(GetIrArray(hlo, hlo)); - } - return num_outputs; -} - -int IrEmitterUnnested::ConstructIrArrayForInputs( - const HloInstruction& hlo, std::vector* param_arrays) { - int64 num_params = hlo.operands().size(); - param_arrays->reserve(num_params); +std::vector IrEmitterUnnested::ConstructIrArrayForInputs( + const HloInstruction& hlo) { + std::vector param_arrays; + param_arrays.reserve(hlo.operands().size()); for (const HloInstruction* param : hlo.operands()) { - param_arrays->push_back(GetIrArray(*param, hlo)); + param_arrays.push_back(GetIrArray(*param, hlo)); } - return num_params; + return param_arrays; } int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( const HloInstruction& hlo, const std::vector& output_arrays, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* output_reduced_shapes, std::vector* output_in_reduced_shape_arrays) { int64 num_outputs = 1; @@ -2915,7 +2839,7 @@ int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* param_reduced_shapes, std::vector* param_in_reduced_shape_arrays) { int64 num_params = hlo.operands().size(); @@ -3056,18 +2980,18 @@ void EmitTiledElementalCodeWithBoundsCheck( // TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient // to launch fewer blocks so each transposes many tiles. LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( - HloInstruction* hlo, tensorflow::gtl::ArraySlice reduced_output_dims, - tensorflow::gtl::ArraySlice tiled_param_ids) { + HloInstruction* hlo, absl::Span reduced_output_dims, + absl::Span tiled_param_ids) { // Parameters for the tiling algorithm. constexpr int64 kTileSize = 32; constexpr int64 kNumRows = 4; constexpr int64 kThreadsPerTile = kTileSize * kNumRows; // Construct IrArrays for the inputs and outputs. - std::vector output_arrays; - int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays); - std::vector param_arrays; - int64 num_params = ConstructIrArrayForInputs(*hlo, ¶m_arrays); + std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); + int64 num_outputs = output_arrays.size(); + std::vector param_arrays = ConstructIrArrayForInputs(*hlo); + int64 num_params = param_arrays.size(); // Allocate shared memory buffers to store the tiled inputs. std::vector param_shmem_buffers(num_params, nullptr); @@ -3102,7 +3026,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( CeilOfRatio(output_dims_in_tiles[i], kTileSize); } const int64 num_tiles = - c_accumulate(output_dims_in_tiles, 1, std::multiplies()); + absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies()); LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); llvm::Type* index_ty = @@ -3148,9 +3072,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( const IrArray::Index output_tile_origin = [&] { IrArray::Index index = output_tile_index; for (int i = 1; i < 3; ++i) { - index[i] = - b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize), - "tile_origin." + std::to_string(i)); + index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize), + "tile_origin." + std::to_string(i)); } return index; }(); @@ -3163,12 +3086,12 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( std::vector output_tile_bounds(3); for (int i = 1; i < 3; ++i) { // Only last row or column may not have full size. - output_tile_bounds[i] = b_.CreateSelect( - b_.CreateICmpEQ(output_tile_index[i], - index_typed_constant(output_dims_in_tiles[i] - 1)), - index_typed_constant(reduced_output_dims[i] - - (output_dims_in_tiles[i] - 1) * kTileSize), - index_typed_constant(kTileSize), "kTileSize"); + output_tile_bounds[i] = + Select(ICmpEQ(output_tile_index[i], + index_typed_constant(output_dims_in_tiles[i] - 1)), + index_typed_constant(reduced_output_dims[i] - + (output_dims_in_tiles[i] - 1) * kTileSize), + index_typed_constant(kTileSize), "kTileSize"); } KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); @@ -3186,7 +3109,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // Adds `addend` to the given `dim` of `index`. auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = b_.CreateAdd(index[dim], addend); + index[dim] = Add(index[dim], addend); return index; }; const IrArray::Index input_index = @@ -3202,10 +3125,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( llvm::Value* shmem_buffer = param_shmem_buffers[id]; // TODO(jlebar): Add AA metadata to this store. Tile buffers are // global variables, so LLVM can't infer much about it. - b_.CreateStore( - input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); + Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); } }); @@ -3226,9 +3148,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_index, "output", output_tile_bounds[2], output_tile_bounds[1], [&](const IrArray::Index& index, llvm::Value* y_loc) { // TODO(jlebar): Add AA metadata to this load. - llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad( - b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), - "output_element"); + llvm::Instruction* load_from_shmem_buffer = + Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), + "output_element"); output_in_reduced_shape_arrays[0].EmitWriteArrayElement( index, load_from_shmem_buffer, &b_); }); @@ -3256,7 +3178,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_in_reduced_shape_arrays.size()); for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { output_in_reduced_shape_arrays[i].EmitWriteArrayElement( - index, b_.CreateExtractValue(output_value, i), &b_); + index, ExtractValue(output_value, i), &b_); } } else { output_in_reduced_shape_arrays[0].EmitWriteArrayElement( @@ -3267,12 +3189,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // For multioutput fusion, emit a tuple with all the individual outputs. if (hlo->IsMultiOutputFusion()) { - std::vector tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_, - module_); + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_); } return launch_dimensions; @@ -3305,7 +3222,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { if (!reduced_dims_021.has_value()) { reduced_dims_021 = curr_reduced_dims_021; } - if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) { + if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) { // There is more than one possible transpose. Instead of picking one // transpose, we simply give up here. return false; @@ -3338,7 +3255,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { // if there's a Right Choice. // // This is only sound if tiled transposes are the only place where we use - // shared memory in fusions. If in the future other fusile ops use shared + // shared memory in fusions. If in the future other fusible ops use shared // memory, we'll have to adjust this heuristic. constexpr int kMinBlocksPerCore = 3; constexpr int64 kShmemPerCore = 48 * 1024; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 525441990795e160ba0e8facb910d5cc9796c4bb..bd5db7205155dc6b15ddea069e172bbd8f419996 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -105,13 +105,12 @@ class IrEmitterUnnested : public IrEmitter { // This kernel takes as arguments pointers to the given buffer allocations. llvm::Function* BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice args); + absl::Span args); // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span> extra_output_gens); // EmitColumnReduction and EmitRowReduction emit code for column and row @@ -127,12 +126,11 @@ class IrEmitterUnnested : public IrEmitter { Status EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a @@ -143,23 +141,21 @@ class IrEmitterUnnested : public IrEmitter { Status EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Emits code that reduces a tensor of arbitrary rank to a scalar. Status EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Figures out whether `reduce` is a row or column reduction, and which @@ -180,13 +176,12 @@ class IrEmitterUnnested : public IrEmitter { // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel @@ -195,18 +190,15 @@ class IrEmitterUnnested : public IrEmitter { // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and // returns the launch dimensions for the kernel. This is a helper to support // the implementation of CheckAndEmitHloWithTile021. - LaunchDimensions EmitHlo021Tile( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice reduced_output_dims, - tensorflow::gtl::ArraySlice tiled_param_ids); - // Generates the IrArray for each output of hlo and returns the number of - // outputs. - int ConstructIrArrayForOutputs(const HloInstruction& hlo, - std::vector* output_arrays); - // Generates the IrArray for each input of hlo and returns the number of - // inputs. - int ConstructIrArrayForInputs(const HloInstruction& hlo, - std::vector* param_arrays); + LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, + absl::Span reduced_output_dims, + absl::Span tiled_param_ids); + + // Generates the IrArray for each input of an hlo and returns a vector that + // constains such IrArrays. + std::vector ConstructIrArrayForInputs( + const HloInstruction& hlo); + // For each output of the `hlo` instruction, constructs the reduced shape for // the output with the given `reduced_output_dims` and cast the original // output IrArray element in `output_arrays` to the reduced shape. Returns @@ -214,7 +206,7 @@ class IrEmitterUnnested : public IrEmitter { int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( const HloInstruction& hlo, const std::vector& output_arrays, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* output_reduced_shapes, std::vector* output_in_reduced_shape_arrays); // For each input of the `hlo` instruction, checks its value in @@ -226,7 +218,7 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* param_reduced_shapes, std::vector* param_in_reduced_shape_arrays); @@ -250,7 +242,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. StatusOr> BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index = {}); + HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index e76823ad103dfa5ba61a0d3ba81b2c028dfeb33e..e09b8fbd3ba275e14accbf88c21f3d10f34198d9 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -15,22 +15,22 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { -KernelThunk::KernelThunk( - tensorflow::gtl::ArraySlice args, - const string& kernel_name, const HloInstruction* hlo_instruction, - int unroll_factor) +KernelThunk::KernelThunk(absl::Span args, + const string& kernel_name, + const HloInstruction* hlo_instruction, + int unroll_factor) : Thunk(Kind::kKernel, hlo_instruction), args_(args.begin(), args.end()), kernel_name_(kernel_name), @@ -41,11 +41,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, tensorflow::mutex_lock lock(mutex_); if (!loader_spec_) { loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece ptx = executable.ptx(); - // Convert tensorflow::StringPiece to se::port::StringPiece because - // StreamExecutor uses the latter. - loader_spec_->AddCudaPtxInMemory( - se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + loader_spec_->AddCudaPtxInMemory(executable.ptx(), kernel_name_); if (!executable.cubin().empty()) { loader_spec_->AddCudaCubinInMemory( @@ -63,7 +59,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, if (kernel_cache_.end() == it) { it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + return InternalError("Unable to load kernel %s", kernel_name_); } } @@ -95,7 +91,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << "Launching " << kernel->name(); // Launch the kernel with potentially multiple blocks and threads. static constexpr int kKernelArgsLimit = 1024; - auto kernel_args = MakeUnique>(); + auto kernel_args = absl::make_unique>(); for (const BufferAllocation* arg : args_) { const auto& buf = buffer_allocations.GetDeviceAddress(arg->index()); kernel_args->add_device_memory_argument(buf); @@ -107,7 +103,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, *kernel_args)) { - return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); + return InternalError("Unable to launch kernel %s", kernel_name_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index d751de50ad6671b3bf88cd4de49a8feb448e13ba..f63db5c3696f8f3bbd5956724240b2b06b4f1b98 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -47,7 +47,7 @@ class KernelThunk : public Thunk { // Constructs a thunk for the given kernel. // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. - KernelThunk(tensorflow::gtl::ArraySlice args, + KernelThunk(absl::Span args, const string& kernel_name, const HloInstruction* hlo_instruction, int unroll_factor); KernelThunk(const KernelThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index eb93efc560efbb4c14065ec98b980a1ca78605c6..698d2d51cc81a6c87f6578f1f35cdb47cf6bb4f2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -34,6 +34,9 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc index 12a8a59488bfdd6ce55f762926cd63ba56bf9d7f..85bc58cb445627695a46171db64cd8a1f10e0fc8 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -86,10 +86,11 @@ void IrDumpingPassManager::run(llvm::Module &module) { const llvm::PassInfo *PI = llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID()); const string basename = ReplaceFilenameExtension( - tensorflow::io::Basename(input_filename_), - tensorflow::strings::Printf( + absl::string_view(tensorflow::io::Basename(input_filename_)), + absl::StrFormat( "pass-%02d.before.%s.ll", i, - (PI == nullptr ? "unknown" : PI->getPassArgument().data()))); + absl::string_view(PI == nullptr ? "unknown" + : PI->getPassArgument().data()))); llvm::legacy::PassManager::add( new DumpIrPass(tensorflow::io::JoinPath(output_dir_, basename))); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index cf44458a2ed6c069c1469bb975c62565264451c1..8751e3a9c2a4c8da46d3ecd8437629450d4a2ba2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -20,13 +20,15 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" @@ -54,10 +56,7 @@ limitations under the License. #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Scalar.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" @@ -107,8 +106,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, << ", " << compute_capability.second << ") ." << "Defaulting to libdevice for compute_" << libdevice_version; } - return tensorflow::strings::StrCat("libdevice.compute_", libdevice_version, - ".10.bc"); + return absl::StrCat("libdevice.compute_", libdevice_version, ".10.bc"); } // Gets the GPU name as it's known to LLVM for a given compute capability. If @@ -138,15 +136,16 @@ static string GetSmName(std::pair compute_capability) { << "Defaulting to telling LLVM that we're compiling for sm_" << sm_version; } - return tensorflow::strings::StrCat("sm_", sm_version); + return absl::StrCat("sm_", sm_version); } // Convenience function for producing a name of a temporary compilation product // from the input filename. string MakeNameForTempProduct(const std::string& input_filename, - tensorflow::StringPiece extension) { - return ReplaceFilenameExtension( - tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension); + absl::string_view extension) { + return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename( + llvm_ir::AsString(input_filename))), + extension); } // Initializes LLVM passes. Uses the PassRegistry mechanism. @@ -167,7 +166,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) { // Returns the TargetMachine, given a triple. std::unique_ptr GetTargetMachine( - llvm::Triple triple, tensorflow::StringPiece cpu_name, + llvm::Triple triple, absl::string_view cpu_name, const HloModuleConfig& hlo_module_config) { std::string error; const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error); @@ -180,7 +179,7 @@ std::unique_ptr GetTargetMachine( TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); llvm_ir::SetTargetOptions( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math(), + .xla_gpu_enable_fast_math(), &target_options); // Enable FMA synthesis. @@ -205,7 +204,7 @@ std::unique_ptr GetTargetMachine( default: codegen_opt_level = CodeGenOpt::None; } - return WrapUnique(target->createTargetMachine( + return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, Optional(RelocModel), Optional(CMModel), codegen_opt_level)); @@ -243,9 +242,9 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level, } // Emits the given module to a bit code file. -void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) { +void EmitBitcodeToFile(const Module& module, absl::string_view filename) { std::error_code error_code; - llvm::ToolOutputFile outfile(filename.ToString().c_str(), error_code, + llvm::ToolOutputFile outfile(string(filename).c_str(), error_code, llvm::sys::fs::F_None); if (error_code) { LOG(FATAL) << "opening bitcode file for writing: " << error_code.message(); @@ -266,8 +265,9 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { // get creative to add a suffix. string module_id(llvm_ir::AsString(module->getModuleIdentifier())); IrDumpingPassManager codegen_passes( - ReplaceFilenameExtension(tensorflow::io::Basename(module_id), - "-nvptx.dummy"), + ReplaceFilenameExtension( + absl::string_view(tensorflow::io::Basename(module_id)), + "-nvptx.dummy"), "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -332,8 +332,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { - return tensorflow::errors::Internal(tensorflow::strings::StrCat( - "Error linking libdevice from ", libdevice_path)); + return tensorflow::errors::Internal( + absl::StrCat("Error linking libdevice from ", libdevice_path)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h index 54e0e140dea1c3a8b21ffde2950c4bc9b703b71c..9654175bfafbb2521743e7894188abe5b5a15217 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc index 9ef9bc3a50fc76f83f05e19163ab339f2da6ef3c..3b2c3591d95ee5a319c82336e9b500d14f88734f 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -17,13 +17,13 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/SourceMgr.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace { @@ -52,14 +52,13 @@ std::unique_ptr LoadIRModule(const string& filename, return module; } -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension) { +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension) { auto pos = filename.rfind('.'); - tensorflow::StringPiece stem = - pos == tensorflow::StringPiece::npos - ? filename - : tensorflow::StringPiece(filename.data(), pos); - return tensorflow::strings::StrCat(stem, ".", new_extension); + absl::string_view stem = pos == absl::string_view::npos + ? filename + : absl::string_view(filename.data(), pos); + return absl::StrCat(stem, ".", new_extension); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h index a6daeca95a6da66cb31b82805a6896f57cb80354..60f4926849cd3e8ad144f657f9feb3c3e1ea25e2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace llvm { class LLVMContext; @@ -41,8 +41,8 @@ std::unique_ptr LoadIRModule(const string& filename, // // For example: // ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc" -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension); +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index c62bae0628f7b2fbfe822104fbe5f3528e0e09c3..c21f76f6eb1874bfa5a1d296c78ea0e3b9261eca 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -23,7 +23,9 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,7 +50,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, // If possible, we want to pick a reduce operand of the fusion root, // because it has the most constraints. for (const auto* inst : fused_expression_root->operands()) { - if (inst->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*inst)) { return inst; } } @@ -63,7 +65,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, auto get_element_shape = [&](const HloInstruction* element_instr) { // Special handling of kReduce instructions -- the fusion // applies to the first operand. - if (element_instr->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*element_instr)) { return element_instr->operand(0)->shape(); } return element_instr->shape(); @@ -85,65 +87,16 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, get_element_shape(element_instr_1), get_element_shape(element_instr_2)); } -namespace { -bool IsInputFusibleReduction(HloInstruction* instr) { - if (instr->IsMultiOutputFusion()) { - for (const HloInstruction* operand : - instr->fused_expression_root()->operands()) { - if (operand->opcode() == HloOpcode::kReduce) { - CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput) - << " Reduce multi-output fusion " << instr->ToString() - << " must be an input fusion."; - return true; - } - } - return false; - } else if (instr->opcode() == HloOpcode::kFusion) { - // The loop emitter can handle to-vector reduce fusions. Such reduce - // fusions have the fusion kind kLoop rather than kInput. We do not fuse - // to-vector reduce fusions, because the resulting fusions may no longer be - // supported by loop emitter. - return IsReductionToVector(*instr->fused_expression_root()); - } else { - return IsReductionToVector(*instr); - } -} - -// The code emitted for reduction suffers from poor data locality if the layouts -// of input parameters differ. In such situtations it is beneficial not to fuse. -// We consider input params with maximum rank only. Params with smaller ranks -// will be broadcasted and have not been observed to cause data locality issues. -// TODO(b/111977086): Improve reduce emitters to remove this limitation. -bool ReduceFriendlyInputLayouts(HloInstruction* instr) { - std::vector params; - if (instr->opcode() == HloOpcode::kFusion) { - params = instr->fused_parameters(); - } else { - for (HloInstruction* operand : instr->operands()) { - params.push_back(operand); - } - } - int64 max_rank = 0; - const Layout* max_rank_layout; - for (HloInstruction* param : params) { - if (ShapeUtil::Rank(param->shape()) > max_rank) { - max_rank = ShapeUtil::Rank(param->shape()); - max_rank_layout = ¶m->shape().layout(); - } - } - return c_all_of(params, [&](HloInstruction* param) { - return (ShapeUtil::Rank(param->shape()) < max_rank) || - (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); - }); -} - -} // namespace - bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { - // We can fuse reduces and loop fusions. - return IsInputFusibleReduction(instr) || - (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop); + // We can fuse reduces and loop fusions. Elementwise instructions can be fused + // with any other instruction. + // TODO(b/112957171): This should use the same isFusible logic as + // instruction_fusion. + return instr->IsFusible() && + (IsInputFusibleReduction(*instr) || + (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kLoop) || + instr->IsElementwise()); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, @@ -177,11 +130,12 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, // merge into bigger loop fusions and input (reduce) fusions become fusions // with multiple reduce outputs. We could fuse reduce and loop fusions // together too (the result being an input fusion) if we find cases where this - // improves things. + // improves things. Also disable fusing standalone input-fusible reduces into + // loop fusions. CHECK(instr1->opcode() == HloOpcode::kFusion); if ((instr2->opcode() == HloOpcode::kFusion && instr1->fusion_kind() != instr2->fusion_kind()) || - (instr2->opcode() != HloOpcode::kFusion && + (IsReductionToVector(*instr2) && instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) { return false; } @@ -197,7 +151,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { tensorflow::gtl::FlatSet to_fuse; // Keep a list of the instructions to fuse after making all the fusion // decisions. We first aggressively add instructions to potential_fusion_list, - // then filter out instructions that will be no longer fusable because of + // then filter out instructions that will be no longer fusible because of // reachability change. This avoids recalculating reachability on a large set // of instructions. std::vector> @@ -212,8 +166,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << consumer->name() << " has no users."; continue; } - if (!IsInputFusibleReduction(consumer)) { - VLOG(3) << consumer->name() << " is not an input-fusable reduction."; + if (!IsInputFusibleReduction(*consumer)) { + VLOG(3) << consumer->name() << " is not an input-fusible reduction."; continue; } VLOG(3) << consumer->name() @@ -222,8 +176,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { auto consumer_operands = consumer->operands(); for (size_t i = 0; i < consumer_operands.size(); ++i) { HloInstruction* producer = consumer_operands[i]; - if (!producer->IsFusable()) { - VLOG(3) << producer->name() << " is not fusable."; + if (!producer->IsFusible()) { + VLOG(3) << producer->name() << " is not fusible."; continue; } const bool is_loop_fusion = @@ -237,7 +191,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << producer->name() << " has an incompatible shape."; continue; } - if (!ReduceFriendlyInputLayouts(producer)) { + if (!LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) { VLOG(3) << producer->name() << " has inputs with mixed layouts."; continue; } @@ -248,7 +202,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } // Do not fuse a producer if the other operands of the fusion are // reachable from the producer, this would create a cycle. - if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + if (absl::c_any_of(consumer_operands, [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { @@ -263,12 +217,12 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } } - // Filter out pairs that will be no longer fusable because of reachability + // Filter out pairs that will be no longer fusible because of reachability // change. for (auto& fusion_pair : potential_fusion_list) { HloInstruction* producer = fusion_pair.first; HloInstruction* consumer = fusion_pair.second; - if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + if (!absl::c_any_of(consumer->operands(), [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 67ca5d49eee8508e93284b134f8410eb3a89f9ce..f0b4d67ab8463a39161f71908746cad9e2a8670a 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -22,7 +22,7 @@ namespace xla { namespace gpu { // Multi-output fusion of sibling and producer-consumer instructions for the -// Jellyfish backend. +// GPU backend. class GpuMultiOutputFusion : public MultiOutputFusion { public: GpuMultiOutputFusion(); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 14f157a5e518a0ec82c664c123629d04bd385bbf..8a6e5327e082791ff857a89e840c6a4f045f0edb 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -15,19 +15,19 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace gpu { +namespace op = xla::testing::opcode_matchers; + using MultiOutputFusionTest = HloTestBase; const char kModulePrefix[] = R"( @@ -47,7 +47,7 @@ const char kModulePrefix[] = R"( TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { // Fusion with reduce instruction root and a sibling reduce instruction // sharing the same input param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -74,7 +74,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[6400]{0} parameter(1) mul = f32[6400]{0} multiply(p1.1, p1.1) @@ -101,7 +101,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -130,7 +130,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) { // Two sibling fusions with reduce instruction roots sharing the same input // param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -165,7 +165,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { // Multi-output fusion with two reduce instructions root and a sibling reduce // instruction sharing the same input param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { const.1 = f32[] constant(1) p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) @@ -198,7 +198,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { // Verify that if we already have a multi-output fusion that we prefer to pick // a reduce op from its operands for checking shape compatibility. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -228,7 +228,7 @@ TEST_F(MultiOutputFusionTest, } TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -256,8 +256,136 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { + // Fusing a reduce into a loop fusion would require changing the fusion kind. + // That's not supported yet. + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(0) + reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation + ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(1) + div = f32[6400]{0} divide(p0, const.2) + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Divide())); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Exp(), op::Add())); +} + +TEST_F(MultiOutputFusionTest, + MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -277,7 +405,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_add { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -304,7 +432,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) @@ -345,7 +473,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_element_wise { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -372,7 +500,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { TEST_F(MultiOutputFusionTest, ProducerConsumerFusionFp16LoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f16[2,2,2]{2,1,0} parameter(1) c0 = f16[] constant(0) @@ -413,7 +541,7 @@ TEST_F(MultiOutputFusionTest, TEST_F(MultiOutputFusionTest, ProducerConsumerFusionReduceUnfriendlyLoopFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( mixed_input_layouts_computation { p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 8fa0439006b95b7a567d8fc8dbec36f193fb0e77..0b3b429710a1a3158ce57a393a09291c95a2ef7a 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -21,13 +21,15 @@ limitations under the License. #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -38,13 +40,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" @@ -72,10 +75,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" -#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -84,7 +87,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -131,11 +133,16 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. +// +// It takes a compiler pointer, as passes may compile and execute HLOs on the +// fly for cuDNN verification or other purposes. Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { + DeviceMemoryAllocator* device_allocator, + Compiler* compiler) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -151,7 +158,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(); + pass.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls // where possible. Not every batchnorm op can be implemented as a call to @@ -168,6 +176,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass(); + pipeline.AddPass(); + pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -196,8 +206,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // Convert convolutions into CustomCalls to cudnn, then canonicalize them // (PadInsertion). HloPassPipeline pipeline("conv_canonicalization"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass(); @@ -205,13 +217,29 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // pairs that TupleSimplifier fixes. pipeline.AddPass(); } + // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add + // instructions which can be simplified by constant folding. + pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } { - HloPassPipeline pipeline("layout_assignment"); + // Run layout assignment in a separate pipeline from + // "post-layout-assignment" because we want everything after layout + // assignment to have a layout-sensitive invariant-checker, but + // HloPassPipeline also runs its invariant checker before any passes are + // run, meaning, the pipeline that contains layout assignment cannot contain + // a layout-sensitive verifier! + HloPassPipeline pipeline("layout assignment"); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), stream_exec); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("post-layout_assignment"); + pipeline.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -246,8 +274,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // the gte(customcall, 0) would probably already be into a fusion node. We // can't simplify across HloComputation boundaries, so in this case we // wouldn't be able to simplify away the new_tuple bits. - pipeline.AddPass(stream_exec, - device_allocator); + pipeline.AddPass( + stream_exec, device_allocator, compiler); // Clean up new_tuple described above. pipeline.AddPass(); @@ -257,17 +285,20 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix fusion("fusion"); - fusion.AddInvariantChecker(); + fusion.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); fusion.AddPass(); fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); + fusion.AddPass(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); - reduce_pipeline.AddInvariantChecker(); + reduce_pipeline.AddInvariantChecker( + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -281,14 +312,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } } - { - // Do an aggressive LICM pass over while loops. In particular, this hoists - // constants that were sunk by WhileLoopConstantSinking. Leaving them in - // the while loop may result in unnecessary copies. - HloPassPipeline pipeline("while-loop-licm"); - pipeline.AddPass(true); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } return Status::OK(); } @@ -301,7 +324,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -351,9 +375,9 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { string vmaj_str, vmin_str, vdot_str; if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, &vmin_str, &vdot_str) || - !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) || - !tensorflow::strings::safe_strto64(vmin_str, &vmin) || - !tensorflow::strings::safe_strto64(vdot_str, &vdot)) { + !absl::SimpleAtoi(vmaj_str, &vmaj) || + !absl::SimpleAtoi(vmin_str, &vmin) || + !absl::SimpleAtoi(vdot_str, &vdot)) { LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path << " --version:\n" << out; @@ -380,7 +404,7 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." << vdot - << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to " + << ", which is older than 9.2.88. ptxas 9.x before 9.2.88 is known to " "miscompile XLA code, leading to incorrect results or " "invalid-address errors.\n\nYou do not need to update to CUDA " "9.2.88; cherry-picking the ptxas binary is sufficient."; @@ -465,7 +489,7 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, tensorflow::SubProcess ptxas_info_dumper; std::vector ptxas_args = { ptxas_path, ptx_path, "-o", cubin_path, - tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)}; + absl::StrCat("-arch=sm_", cc_major, cc_minor)}; if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } @@ -501,11 +525,15 @@ NVPTXCompiler::NVPTXCompiler() StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { + // We dump the post-optimization HLO in RunBackend so no need to dump it here. + VLOG(2) << "*** HLO Before Optimization"; + XLA_VLOG_LINES(2, module->ToString()); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); TF_RETURN_IF_ERROR( - OptimizeHloModule(module.get(), stream_exec, device_allocator)); + OptimizeHloModule(module.get(), stream_exec, device_allocator, this)); return std::move(module); } @@ -539,8 +567,8 @@ StatusOr> NVPTXCompiler::RunBackend( // must also be used to determine the thunk launch schedule. std::unique_ptr stream_assignment = AssignStreams(*module); TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_schedule, - HloSchedule::Build(*module, *stream_assignment, pointer_size_)); + std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -557,6 +585,7 @@ StatusOr> NVPTXCompiler::RunBackend( // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); + VLOG(2) << "*** HLO After Optimization"; XLA_VLOG_LINES(2, module->ToString()); const string xla_dump_optimized_hlo_proto_to = module->config().debug_options().xla_dump_optimized_hlo_proto_to(); @@ -668,7 +697,7 @@ StatusOr> NVPTXCompiler::RunBackend( // Write PTX to IR dump directory, if IR dumping was requested. if (!ir_dump_directory.empty()) { const string ptx_outfile = tensorflow::io::JoinPath( - ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx")); + ir_dump_directory, absl::StrCat(module->name(), ".ptx")); auto status = [&] { auto* env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); @@ -684,7 +713,7 @@ StatusOr> NVPTXCompiler::RunBackend( const std::vector cubin = CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); - auto thunk_schedule = MakeUnique( + auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); VLOG(2) << "Printing the thunk schedule..."; @@ -698,7 +727,7 @@ StatusOr> NVPTXCompiler::RunBackend( cost_analysis.set_bytes_per_second( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - profile_index_map = MakeUnique(*module); + profile_index_map = absl::make_unique(*module); profile_printer = CreateHloProfilePrinterData(*profile_index_map, cost_analysis); } @@ -807,7 +836,7 @@ se::Platform::Id NVPTXCompiler::PlatformId() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::cuda::kCudaPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index d4d2909f1b2dc57c3ae0f9d67067e533574369dd..8e97774750344bfc141daa7d752300762c708613 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -20,13 +20,13 @@ limitations under the License. #include #include +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc index 4aaf0c9e142106a0e74f319d71dad4c4c96d3f08..2fa170964e974a6535307d7a21eb3e7760d02536 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h index a752eb70119b00e8cca7ddce26da7730ef5db8cb..160ba4b691f818ff01b41b8603c11853ea12c253 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h @@ -36,22 +36,19 @@ class OutfeedBuffer { OutfeedBuffer(int64 length) : length_(length) {} // Waits for the device transfer to be finished. - std::unique_ptr WaitUntilAvailable() { - done_.WaitForNotification(); - return std::move(destination_); - } + void WaitUntilAvailable() { done_.WaitForNotification(); } int64 length() const { return length_; } - void set_destination(std::unique_ptr destination) { + void set_destination(std::unique_ptr destination) { destination_ = std::move(destination); } - Literal* destination() { return destination_.get(); } + MutableBorrowingLiteral* destination() { return destination_.get(); } // Callback to signal that this buffer is consumed. void Done() { done_.Notify(); } private: - std::unique_ptr destination_; + std::unique_ptr destination_; const int64 length_; tensorflow::Notification done_; }; diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 7986e63f43ee508370f94fdb9057b91bfe4add18..e0f3e84a4cb25792cf10d38fc529f3e638acf8e4 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -50,10 +50,6 @@ Status OutfeedThunk::ExecuteOnStream( if (!*buffer) { // Tuple pointers. return Status::OK(); } - // Allocate storage for the literal data. - const Shape& shape = - ShapeUtil::GetSubshape(outfeed_buffers->shape(), index); - (*buffer)->set_destination(Literal::CreateFromShape(shape)); BufferAllocation::Slice slice = outfeed_slices_.element(index); se::DeviceMemoryBase data_address; @@ -100,7 +96,7 @@ Status OutfeedThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Outfeeding from GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index 79f7d31816baf0b95b967771b956a9c06ac81e91..e3869b5c368957571219a39600214140022a7318 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -17,14 +17,13 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" namespace xla { namespace gpu { -using tensorflow::gtl::ArraySlice; - // We want the input/output feature counts of an f16 conv to be factors of 8, // because without this cudnn can't use tensor cores on the conv. static constexpr int64 kDesiredNumFeaturesFactor = 8; @@ -38,15 +37,32 @@ static constexpr int64 kDesiredNumFeaturesFactor = 8; // there's additional room for speedups. Achieving those speedups without also // slowing other things down will likely require a more sophisticated heuristic, // possibly some form of auto-tuning. -static constexpr double kMaxBytesTouchedIncrease = 1.2; +// +// This value should be >= 4/3, otherwise the "dims of size 3 padded up to 4" +// special case inside PadShape won't fire. +static constexpr double kMaxBytesTouchedIncrease = 1.35; // Pads the given dimensions in the given shape up to a multiple of // kDesiredNumFeaturesFactor. -static Shape PadShape(Shape s, ArraySlice dims) { +static Shape PadShape(Shape s, absl::Span dims) { for (int64 dim : dims) { int64 dim_to_pad_size = s.dimensions(dim); - int64 new_dim_to_pad_size = - RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + + // Round dim_to_pad_size up to the next multiple of + // kDesiredNumFeaturesFactor. + // + // Special case: dims of size 3 are rounded up to 4, not + // kDesiredNumFeaturesFactor. Empirically (and on the advice of nvidia), + // this helps, but as of writing, it's not supported by anything in the + // cudnn docs. + int64 new_dim_to_pad_size; + if (dim_to_pad_size == 3) { + new_dim_to_pad_size = 4; + } else { + new_dim_to_pad_size = + RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + } + s.set_dimensions(dim, new_dim_to_pad_size); } return s; @@ -64,8 +80,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloComputation* comp = instr->parent(); const Shape& shape = instr->shape(); - auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + auto* zero = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); @@ -211,7 +227,11 @@ static std::vector GetRelevantConvs(HloComputation* comp) { std::vector convs; for (HloInstruction* instr : comp->instructions()) { if (IsCustomCallToDnnConvolution(*instr) && - instr->operand(0)->shape().element_type() == F16) { + instr->operand(0)->shape().element_type() == F16 && + // TODO(timshen): Disable for fused conv for now. Implement it if it's + // needed. + Cast(instr)->custom_call_target() != + kCudnnConvBiasActivationForwardCallTarget) { convs.push_back(instr); } } diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h index 192359f026bfb2f1d5436713e4a30725fa0ad6ba..e592a3774ec28605fda912298c74ca7976ff99ac 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -30,11 +30,9 @@ namespace gpu { // targeting before running this pass. // // TODO(jlebar): Also pad dots. -class PadForTensorCores : public HloPassInterface { +class PadForTensorCores : public HloModulePass { public: - tensorflow::StringPiece name() const override { - return "pad for tensor cores"; - } + absl::string_view name() const override { return "pad for tensor cores"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc index 99e7580b826fc5cd6d98a037a5eb064552952e18..5c92b0dcb873b873074704dca8f27d4067b070df 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -29,7 +29,7 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -using PadForTensorCoresTest = HloVerifiedTestBase; +class PadForTensorCoresTest : public HloVerifiedTestBase {}; TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ParseAndVerifyModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index b22040eee167e784bed58dbc0d0ad2ae042037f3..b42a19e3a2200e917f8040be183b8d79c9e4e161 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -29,7 +30,8 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); + CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget || + conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget); return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); @@ -67,9 +69,8 @@ HloInstruction* MaybePaddedAndSlicedInput( conv_window.dimensions(i).base_dilation() - 1); } PrimitiveType element_type = input->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -124,9 +125,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloComputation* computation = kernel->parent(); PrimitiveType element_type = kernel->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -162,12 +162,14 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract // out the shape of conv_result. - Shape old_conv_shape = conv->shape().tuple_shapes(0); - VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel, - new_conv_window, - conv->convolution_dimension_numbers()); + std::vector operands(conv->operands().begin(), + conv->operands().end()); + operands[0] = new_input; + operands[1] = new_kernel; + auto new_conv = conv->parent()->AddInstruction( + conv->CloneWithNewOperands(conv->shape(), operands)); + new_conv->set_window(new_conv_window); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); @@ -235,18 +237,18 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( - LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(input->shape().element_type()))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); // The shape of the backward_conv CustomCall is a tuple (conv_result, // scratch_buffer). Extract out the shape of conv_result. - Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); - HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( - backward_conv_shape, padded_input, output, new_backward_conv_window, - backward_conv_dnums); + HloInstruction* new_backward_conv = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + backward_conv->shape(), {padded_input, output})); + new_backward_conv->set_window(new_backward_conv_window); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -309,9 +311,12 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( HloInstruction* output = backward_conv->mutable_operand(0); HloInstruction* filter = backward_conv->mutable_operand(1); - HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( - new_backward_conv_shape, output, filter, new_backward_conv_window, - backward_conv_dnums); + HloInstruction* new_backward_conv_call = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + ShapeUtil::MakeTupleShape( + {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}), + {output, filter})); + new_backward_conv_call->set_window(new_backward_conv_window); // The CustomCall created above returns a tuple (conv_result, scratch_memory). // Extract out the two elements. @@ -381,7 +386,8 @@ StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { } for (HloInstruction* instruction : convs) { const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBiasActivationForwardCallTarget) { changed |= CanonicalizeForwardConvolution(instruction); } else if (target == kCudnnConvBackwardFilterCallTarget) { changed |= CanonicalizeBackwardFilterConvolution(instruction); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 67e51509e4c717951c83c7e41943af1de762dee0..25cdf64c4cf01300869044d3e4d7c34c85626a5a 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -24,9 +24,9 @@ namespace gpu { // An HLO pass that canonicalizes convolution instructions for GPU codegen. It // inserts Pad instructions before Convolution instructions with uncanonicalized // padding, so that they can be lowered to cuDNN convolution. -class PadInsertion : public HloPassInterface { +class PadInsertion : public HloModulePass { public: - tensorflow::StringPiece name() const override { return "pad insertion"; } + absl::string_view name() const override { return "pad insertion"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 3838fee674566196e10ddd98462c1a1aa7835e1a..8154d75d23a6d49153ccb6824402aff73f365617 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -40,7 +40,7 @@ ParallelLoopEmitter::ParallelLoopEmitter( ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, + absl::Span target_arrays, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) : LoopEmitter(target_element_generator, target_arrays, b), @@ -57,8 +57,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( unroll_factor_(unroll_factor) {} std::vector -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index b82a23419df08cafdc69b6d2f14528484b95dc73..f32ea1ce4c4192f39851a6441c46663df3063724 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -47,18 +47,17 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { // // This is used in multi-output fusion. target_element_generator should // produce a struct with N elements, one for each of target_arrays. - ParallelLoopEmitter( - const llvm_ir::ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, - int unroll_factor = 1); + ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + absl::Span target_arrays, + const LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* b, int unroll_factor = 1); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: // The thread and block dimension to parallelize the loop on. diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d3fd0544fb68809125e9b9f7a5e5b7eff8c6ef43..cf9f102d31305da15dabaf6247f23c5ca9a9e054 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -34,9 +34,8 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - out << tensorflow::strings::Printf("[block: %lld, thread: %lld]", - launch_dims.block_count(), - launch_dims.threads_per_block()); + out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(), + launch_dims.threads_per_block()); return out; } @@ -91,9 +90,9 @@ LaunchDimensions CalculateLaunchDimensions( } int64 block_count = CeilOfRatio(num_elements, threads_per_block); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << absl::StrFormat( "Initialized the block count to ceil(# of elements / threads per " - "block) = ceil(%lld/%lld) = %lld", + "block) = ceil(%d/%d) = %d", num_elements, threads_per_block, block_count); return LaunchDimensions(block_count, threads_per_block); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 0806dd51614f4d2da12f3fbbc9fb98df5273d5c8..5b6cf2c04d05378a363232e33a6df6432cd6848e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" @@ -119,7 +119,7 @@ int ComputeStreamToAssign( } // namespace std::unique_ptr AssignStreams(const HloModule& module) { - auto stream_assignment = MakeUnique(); + auto stream_assignment = absl::make_unique(); const HloComputation& computation = *module.entry_computation(); std::unique_ptr reachability = computation.ComputeReachability(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 6f4bb0580e8dfc1dce1cca0a60cc3dd9ea600fb3..c4f43cc9a614283acb376b5f98e4976615b590ad 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -15,25 +15,27 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloTestBase { +class StreamAssignmentTest : public HloVerifiedTestBase { protected: std::unique_ptr CreateNewModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", config); + return absl::make_unique("test_module", config); } // Pre-canned shapes. @@ -48,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -67,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -97,26 +99,26 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 05b305ea4cdfdbaeb42544b626a6b9990bb42f57..08ff52211af163fec39646ca6bf14da9d1b815e4 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace gpu { @@ -53,8 +55,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, input_layout.push_back(dnums.input_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid input layout: ", - DataLayoutString(input)); + return InternalError("Invalid input layout %s for conv with dnums %s", + DataLayoutString(input), + ConvolutionDimensionNumbersToString(dnums)); } std::vector filter_layout; @@ -74,8 +77,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, filter_layout.push_back(dnums.kernel_input_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid filter layout: ", - FilterLayoutString(filter)); + return InternalError("Invalid filter layout %s for conv with dnums %s", + FilterLayoutString(filter), + ConvolutionDimensionNumbersToString(dnums)); } std::vector output_layout; @@ -95,8 +99,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, output_layout.push_back(dnums.output_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid output layout: ", - DataLayoutString(output)); + return InternalError("Invalid output layout %s for conv with dnums %s", + DataLayoutString(output), + ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), @@ -128,8 +133,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(input, nhwc_input)) { input_layout = DataLayout::kBatchYXDepth; } else { - return tensorflow::errors::Internal("Invalid input layout: ", - input.ShortDebugString()); + return InternalError("Invalid input layout %s for conv with dnums %s", + LayoutUtil::HumanString(input), + ConvolutionDimensionNumbersToString(dnums)); } FilterLayout filter_layout; @@ -138,8 +144,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(filter, nhwc_filter)) { filter_layout = FilterLayout::kOutputYXInput; } else { - return tensorflow::errors::Internal("Invalid filter layout: ", - filter.ShortDebugString()); + return InternalError("Invalid filter layout %s for conv with dnums %s", + LayoutUtil::HumanString(filter), + ConvolutionDimensionNumbersToString(dnums)); } DataLayout output_layout; @@ -148,8 +155,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(output, nhwc_output)) { output_layout = DataLayout::kBatchYXDepth; } else { - return tensorflow::errors::Internal("Invalid output layout: ", - output.ShortDebugString()); + return InternalError("Invalid output layout %s for conv with dnums %s", + LayoutUtil::HumanString(output), + ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(input_layout, filter_layout, output_layout); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 4fad3f46cf953945e4f395e751e5ba76db97ecc4..a7255335672a3622d122e9fc5ebfab236a5ba895 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -25,32 +25,32 @@ filegroup( ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) cc_library( name = "gpu_codegen_test", testonly = True, srcs = ["gpu_codegen_test.cc"], hdrs = ["gpu_codegen_test.h"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service/gpu:gpu_executable", "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -60,15 +60,14 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) tf_cc_test( name = "gpu_ftz_test", srcs = ["gpu_ftz_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/core:test_main", @@ -78,9 +77,7 @@ tf_cc_test( tf_cc_test( name = "gpu_index_test", srcs = ["gpu_index_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -94,15 +91,14 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) tf_cc_test( name = "gpu_infeed_test", srcs = ["infeed_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -123,9 +119,7 @@ tf_cc_test( tf_cc_test( name = "gpu_kernel_tiling_test", srcs = ["gpu_kernel_tiling_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo", @@ -140,7 +134,7 @@ tf_cc_test( tf_cc_test( name = "gpu_ldg_test", srcs = ["gpu_ldg_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -150,15 +144,14 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) tf_cc_test( name = "gpu_noalias_test", srcs = ["gpu_noalias_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -168,15 +161,14 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) tf_cc_test( name = "gpu_fusion_test", srcs = ["gpu_fusion_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -190,9 +182,7 @@ tf_cc_test( tf_cc_test( name = "gpu_unrolling_test", srcs = ["gpu_unrolling_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -207,9 +197,7 @@ tf_cc_test( name = "gpu_alignment_test", testonly = True, srcs = ["gpu_alignment_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:gpu_plugin", @@ -221,3 +209,17 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cudnn_fused_convolution_rewriter_test", + srcs = ["cudnn_fused_convolution_rewriter_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5632cac1862e21825888d94ab1eee5e1c9fd6800 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc @@ -0,0 +1,283 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class CudnnFusedConvolutionRewriterTest : public HloTestBase { + protected: + string GetOptimizedHlo(absl::string_view hlo_string) { + return backend() + .compiler() + ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest()) + .ConsumeValueOrDie(), + backend().default_stream_executor(), + backend().memory_allocator()) + .ConsumeValueOrDie() + ->ToString(); + } + + void TestMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); + EXPECT_EQ(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convForward")) + << optimized_hlo_string; + EXPECT_NE(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convBiasActivationForward")) + << optimized_hlo_string; + EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01})) + << optimized_hlo_string; + } + } + + void TestNotMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + string optimized_hlo = GetOptimizedHlo(hlo_with_new_type); + EXPECT_NE(absl::string_view::npos, + optimized_hlo.find("__cudnn$convForward")) + << optimized_hlo; + EXPECT_EQ(absl::string_view::npos, + optimized_hlo.find("__cudnn$convBiasActivationForward")) + << optimized_hlo; + } + } +}; + +TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) { + // max(0, conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) { + // max(0, conv(x, w) + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) { + // max(0, conv(x, w) + side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) { + // max(0, conv(x, w) + side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) { + // max(0, 0.999994934 * conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={} + scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv) + ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) { + // max(0, conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, + TestScaledConvAndScaledSideInputWithBias) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) { + // max(0.1, conv(x, w)) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + point_one = TYPE[] constant(0.1) + point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) { + // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input1 = TYPE[1,3,3,64] parameter(2) + side_input2 = TYPE[1,3,3,64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input2) + add2 = TYPE[1,3,3,64] add(add1, side_input1) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 4b8415fe9106137e588f345a3492f93e46aeb5b6..79e77d4c4d649020cf52ac25c220c3f90e8469b9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/logging.h" @@ -32,15 +32,14 @@ std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { debug_options.add_xla_disable_hlo_passes("constant_folding"); config.set_debug_options(debug_options); - return MakeUnique(TestName(), config); + return absl::make_unique(TestName(), config); } void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr hlo_module, const string& pattern) { std::unique_ptr executable = std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie()); - string ptx_str = - std::string(static_cast(executable.get())->ptx()); + string ptx_str(static_cast(executable.get())->ptx()); StatusOr filecheck_result = RunFileCheck(ptx_str, pattern); ASSERT_TRUE(filecheck_result.ok()); EXPECT_TRUE(filecheck_result.ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index ce69e058e64aab1f3c292b2ad7c7b529d4666b35..780539c164277f14c2bd964024f7c3ca179f4ada 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {}; TEST_F(GpuCopyTest, UseMemcpy) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index e5958165eff21d82faf821213e50fe30a11059a4..a06576df7b874745236a8d9075355a01ec42e777 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index cca35316f0c472d2a17c466f8cd1af7f22575a8b..15d1e269cc22b88f5269175084f20600f165011c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -27,13 +27,22 @@ namespace { class GpuKernelTilingTest : public GpuCodegenTest { protected: - GpuKernelTilingTest() { + GpuKernelTilingTest() {} + + // Most tests in this file want to skip layout assignment, but a few need it + // enabled. + HloModuleConfig ConfigWithLayoutAssignment() { + return GetModuleConfigForTest(); + } + + HloModuleConfig ConfigWithoutLayoutAssignment() { + HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); - config_.set_debug_options(debug_options); // Disable layout_assignment to use the preassigned layouts. - debug_options.add_xla_disable_hlo_passes("layout_assignment"); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + return config; } - HloModuleConfig config_; }; TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { @@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // + // We must enable layout assignment in order for this test to work correctly. + // AlgebraicSimplifier removes copy1; it's added back by layout assignment, + // which respects the module's entry computation layout. But if we don't run + // layout assignment...well, nobody else adds the copy back. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0) })"; - // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // Check that a call to llvm.nvvm.barrier0 is not generated. As in + // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment + // here. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest, })"; // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 6c9ae7bada5e7545b558b6fcb872ece60850cbe9..6a9ecd9dae7c9ddde0b56d8615e4a39fb3df0af9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index c42e5704a4d2e611a203293e60a86ba4104bca46..15198865bda98f9718342d5a444a20305f923b48 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 962293630683fcbbce3941f622061a2ff0f02dda..0f2d5568cafc9db0f5f067437fdd5e2e775ad2c8 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_max_kernel_unroll_factor(2); + // Disable layout assignment for this test. Layout assignment does not expect + // fusions to be present, and so it does the wrong thing. + debug_options.add_xla_disable_hlo_passes("layout-assignment"); config.set_debug_options(debug_options); const char *const kMultiOutputFusionModule = R"( diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index 9072b30317d253fd6d50e9d98949cad4eaebfe7b..f8120a5fa00ce38644cd85c54d5ef65701be1eda 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } @@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) { TEST_F(InfeedTest, LargeInfeed) { Array4D array(80, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D(array)); + TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D(array)); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests that a large tuple infeed can be handled. TEST_F(InfeedTest, SingleInfeedLargeTuple) { Array4D array(40, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR4FromArray4D(array).get(), - LiteralUtil::CreateR0(5).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR4FromArray4D(array), + LiteralUtil::CreateR0(5)})); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 4df0bb005b623e5ac79a4dfcb7c5a8a7a400940c..e68bee035a029178844282995429eaa960cc4817 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -82,17 +82,9 @@ class Thunk { return Status::OK(); } - // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) - // before calling ExecuteOnStream(stream). If it returns true, it's the - // user's responsibility to wait for all activity on the GPU to finish before - // calling ExecuteOnStream. - // - // This value is not required to be constant for a given Thunk. For example, - // a Thunk that performs autotuning may return true for its first run and - // false thereafter. - virtual bool ShouldHaltAllActivityBeforeRunning(se::Stream* /*stream*/) { - return false; - } + // Returns true if this kernel will autotune for the stream device the next + // time it is run. + virtual bool WillAutotuneKernel(se::Stream* /*stream*/) { return false; } // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index bdb062837c5ba4b588ea0d535a786f33fe4f4015..141f3219387940a08ef22cbcc0be0971a14c2cd6 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -144,16 +144,15 @@ const std::list& ThunkSchedule::DependsOn( string ThunkSchedule::ToString() const { string result = "Total order:\n"; for (Thunk* thunk : thunk_total_order_) { - tensorflow::strings::StrAppend(&result, "\t", - thunk->hlo_instruction()->ToString(), "\n"); + absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n"); } - tensorflow::strings::StrAppend(&result, "Dependencies:\n"); + absl::StrAppend(&result, "Dependencies:\n"); for (const auto& entry : depends_on_) { const Thunk* dependent = entry.first; for (const Thunk* dependency : entry.second) { - tensorflow::strings::StrAppend( - &result, "\t", dependent->hlo_instruction()->name(), " depends on ", - dependency->hlo_instruction()->name(), "\n"); + absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(), + " depends on ", dependency->hlo_instruction()->name(), + "\n"); } } return result; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index a10e40451c1db01ce73db7b56a3a0599769fa49b..989b542ff4503600b2e3c751a23345959fab6fd6 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" @@ -24,24 +25,32 @@ namespace gpu { Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - std::vector tuple_element_buffer_addresses; - for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { - tuple_element_buffer_addresses.push_back( - buffer_allocations.GetDeviceAddress(tuple_element_buffer).opaque()); + auto size = tuple_element_buffers_.size(); + auto tuple_element_buffer_addresses = absl::make_unique(size); + for (int i = 0; i != size; ++i) { + tuple_element_buffer_addresses[i] = + buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque(); } se::DeviceMemory dest_buffer_address( buffer_allocations.GetDeviceAddress(dest_buffer_)); - auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*); + auto host_size = size * sizeof(void*); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (!stream ->ThenMemcpy(&dest_buffer_address, - tuple_element_buffer_addresses.data(), host_size) + tuple_element_buffer_addresses.get(), host_size) .ok()) { return InternalError( "Unable to launch MemcpyH2D from %p to %p with size %lu", - tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), - sizeof(void*) * tuple_element_buffer_addresses.size()); + tuple_element_buffer_addresses.get(), dest_buffer_address.opaque(), + host_size); + } + // Free the tuple address buffer when memcpy is done. + auto* buffers_raw = tuple_element_buffer_addresses.release(); + if (!stream->ThenDoHostCallback([buffers_raw] { delete[] buffers_raw; }) + .ok()) { + delete[] buffers_raw; + return InternalError("Unable to enqueue host callback!"); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 2d5735d6c40ccd26f0e527f1a02403910db4c812..dcdbf2cf3c2aa87cc11a3473a765cb405b50e2a6 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -18,12 +18,12 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -34,8 +34,7 @@ namespace gpu { // issue (b/31336476). class TupleThunk : public Thunk { public: - TupleThunk(tensorflow::gtl::ArraySlice - tuple_element_buffers, + TupleThunk(absl::Span tuple_element_buffers, const BufferAllocation::Slice& dest_buffer, const HloInstruction* hlo_instruction) : Thunk(Kind::kTuple, hlo_instruction), diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index d81d87e7dc54cd752000b85f3ec173d66d7195e4..c4754fe378960834e1157b0ff25c03c0fc4754c7 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -34,9 +34,9 @@ WhileThunk::WhileThunk( // and body_thunk_sequence_ constructors because these SequentialThunks // are logically "part of" this WhileThunk, and shouldn't be profiled // separately from it. - condition_thunk_sequence_(MakeUnique( + condition_thunk_sequence_(absl::make_unique( std::move(*condition_thunk_sequence), nullptr)), - body_thunk_sequence_(MakeUnique( + body_thunk_sequence_(absl::make_unique( std::move(*body_thunk_sequence), nullptr)) {} Status WhileThunk::Initialize(const GpuExecutable& executable, @@ -70,7 +70,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", stream, - block_status.error_message().c_str()); + block_status.error_message()); } if (!condition_result) { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc deleted file mode 100644 index c5321df6c466fcb3816fb2aedad65b7c3811cb37..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ /dev/null @@ -1,521 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - -#include -#include - -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { -namespace gpu { - -namespace { - -// TODO(b/33483676) Use an expression tree to specify computations to pattern -// match for while transformations. - -// ExprTree is a simple recursive data structure used to express computation -// patterns to match. -// -// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifying the index and -// HloOpcode of the operand. -// -// For example, the following computation: -// -// Parameter -// | -// Const GetTupleElement -// \ / -// Add (root) -// -// Can be matched with the following expression tree: -// -// ExprTree add(HloOpcode::kAdd, -// ExprTree(HloOpcode::kConstant), -// ExprTree(HloOpcode::kGetTupleElement, -// tuple_index, ExprTree(HloOpcode::kParameter))); -// -// Match the ExprTree root against an Hlo graph: -// -// ExprTree::TaggedInstructionMap tagged_instructions; -// TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(), -// &tagged_instructions)); -// -// Instructions that are "tagged" with a context-specific string will -// be returned in 'tagged_instructions' for further processing (i.e. parsing -// constants or recording the tuple_index). -// -class ExprTree { - public: - explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {} - ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {} - ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0) - : opcode_(opcode) { - SetOperand(index0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0, - int64 index1, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(index0, operand0); - SetOperand(index1, operand1); - } - ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0) - : opcode_(opcode), tag_(tag) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(0, operand0); - SetOperand(1, operand1); - } - - ExprTree(const ExprTree& to_copy) { - opcode_ = to_copy.opcode_; - tag_ = to_copy.tag_; - if (to_copy.fused_root_tree_ != nullptr) { - fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_)); - } - for (auto& pair : to_copy.operands_) { - CHECK(operands_.find(pair.first) == operands_.end()); - operands_.insert(std::make_pair( - pair.first, std::unique_ptr(new ExprTree(*pair.second)))); - } - } - - void SetFusedRoot(const ExprTree& fused_root) { - fused_root_tree_.reset(new ExprTree(fused_root)); - } - - typedef std::unordered_map - TaggedInstructionMap; - - // Matches 'instruction' HloOpcode against 'opcode_'. - // Recursively matches each operand in 'operands_'. - // Recursively matches fused instructions starting at 'fused_root_tree_' - // if 'opcode_ == kFusion'. - // Returns OK status, and instructions in 'tagged_instructions' for each - // matched ExprTree node with a non-empty 'tag_'. - // Returns error message on failure. - Status Match(const HloInstruction* instruction, - TaggedInstructionMap* tagged_instructions) const { - if (opcode_ != instruction->opcode()) { - return InvalidArgument("got opcode %s, want %s", - HloOpcodeString(instruction->opcode()).c_str(), - HloOpcodeString(opcode_).c_str()); - } - - VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; - if (!tag_.empty()) { - tagged_instructions->insert({tag_, instruction}); - } - - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK(fused_root_tree_ != nullptr); - // Match fused instructions for this node starting a 'fused_root_tree'. - TF_RETURN_IF_ERROR(fused_root_tree_->Match( - instruction->fused_expression_root(), tagged_instructions)); - } - - // Match each operand in 'operands_'. - for (auto& pair : operands_) { - TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), - tagged_instructions)); - } - return Status::OK(); - } - - private: - void SetOperand(int64 index, const ExprTree& operand) { - CHECK_EQ(0, operands_.count(index)); - operands_.insert(std::make_pair(index, MakeUnique(operand))); - } - - HloOpcode opcode_; - std::unordered_map> operands_; - std::unique_ptr fused_root_tree_; - string tag_; -}; - -// MatcherBase is a base class that provides common functionality for -// sub-classes which match specific target sub-computations (i.e. loop -// induction variable initialization, comparison and update). -class MatcherBase { - public: - MatcherBase() {} - virtual ~MatcherBase() {} - - // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first successful match, error status otherwise. - virtual Status Run() { - Status status; - for (const ExprTree& expr_tree : expr_trees_) { - status = MatchExprTree(expr_tree); - if (status.ok()) { - return status; - } - } - return status; - } - - virtual Status MatchExprTree(const ExprTree& expr_tree) = 0; - - // Returns the constant value parsed form kConstant 'instruction'. - // Returns error status otherwise. - Status ParseConstInteger(const HloInstruction* instruction, - int64* const_value) const { - CHECK_EQ(HloOpcode::kConstant, instruction->opcode()); - PrimitiveType element_type = instruction->shape().element_type(); - if (element_type != S32 && element_type != S64) { - return InvalidArgument("Expected constant of integral type."); - } - const Literal& literal = instruction->literal(); - PrimitiveType type = literal.shape().element_type(); - if (type != S32 && type != S64) { - return InvalidArgument("Must use S32 or S64 integral types."); - } - if (type == S32) { - *const_value = static_cast(literal.GetFirstElement()); - } else if (type == S64) { - *const_value = literal.GetFirstElement(); - } - return Status::OK(); - } - - StatusOr GetTaggedInstruction( - const string& tag, - const ExprTree::TaggedInstructionMap& tagged_instructions) { - auto it = tagged_instructions.find(tag); - if (it == tagged_instructions.end()) { - return InvalidArgument("Cound not find instruction for tag: %s", - tag.c_str()); - } - return it->second; - } - - protected: - std::vector expr_trees_; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase); -}; - -// WhileConditionComputationMatcher attempts to match a target computation -// pattern in the while condition sub-computation. -// If the target pattern is matched, two pieces of information are extracted -// from 'tagged' instructions returned by the matcher: -// -// *) 'tuple_index': -// *) The loop induction variable tuple_index from the GetTupleElement -// instruction of the matched computation. -// *) Used in subsequent matching passes of while init operand and body -// computations to select loop induction variable tuple element. -// -// *) 'loop_limit': -// *) The integral value from Constant root operand in matched computation. -// *) Used as the constant for the loop limit. -// -class WhileConditionComputationMatcher : public MatcherBase { - public: - explicit WhileConditionComputationMatcher(const HloComputation* computation) - : computation_(computation) { - expr_trees_.emplace_back(BuildCondExprTree()); - } - - int64 loop_limit() const { return loop_limit_; } - int64 tuple_index() const { return tuple_index_; } - - private: - // Builds expression tree for the following condition computation: - // - // Const Parameter - // \ / - // Fusion ------------> FusionParam FusionParam - // \ / - // GTE / - // \ / - // LessThan (fused root) - // - ExprTree BuildCondExprTree() { - // Build ExprTree for fused instructions. - ExprTree fused_root( - HloOpcode::kLt, - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")), - ExprTree(HloOpcode::kParameter)); - - // Build top-level computation. - ExprTree root(HloOpcode::kFusion, - ExprTree(HloOpcode::kConstant, "loop_limit"), - ExprTree(HloOpcode::kParameter, "param0")); - - root.SetFusedRoot(fused_root); - return root; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while condition"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - // Get tagged GTE instruction and set 'tuple_index_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* gte, - GetTaggedInstruction("gte", tagged_instructions)); - tuple_index_ = gte->tuple_index(); - - // Get tagged Constant instruction and parse 'loop_limit_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_limit", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_)); - - // Get tagged "param0" instruction, and check that it matches - // 'computation_' parameter 0. - TF_ASSIGN_OR_RETURN(const HloInstruction* param0, - GetTaggedInstruction("param0", tagged_instructions)); - if (param0 != computation_->parameter_instruction(0)) { - return InvalidArgument("Unexpected Parameter0 instruction : %s", - param0->name().c_str()); - } - - // Get tagged 'gte.fusion_param.param0', find its associated fusion operand, - // and compare it to 'computation_' parameter0. - TF_ASSIGN_OR_RETURN( - const HloInstruction* gte_fusion_param0, - GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); - CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); - CHECK(gte_fusion_param0->IsFused()); - if (gte_fusion_param0->parent()->FusionInstruction()->operand( - gte_fusion_param0->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - gte_fusion_param0->name().c_str()); - } - - return Status::OK(); - } - - const HloComputation* computation_; - - int64 loop_limit_ = -1; - int64 tuple_index_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher); -}; - -// WhileInitOperandMatcher matches a target computation pattern of the -// while instructions 'init' operand, indexing the tuple at 'tuple_index'. -// On success, parses constant 'loop_start' which represents the loop induction -// variable start values, then returns OK. -// Returns error status otherwise. -class WhileInitOperandMatcher : public MatcherBase { - public: - WhileInitOperandMatcher(const HloInstruction* while_hlo, - const int64 tuple_index) - : while_hlo_(while_hlo), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildInitExprTree()); - } - - int64 loop_start() const { return loop_start_; } - - private: - // Builds expression tree for the following while init operand subcomputation: - // - // Const - // | - // Copy - // | - // Tuple0 - // | - // While - // - ExprTree BuildInitExprTree() { - return ExprTree( - HloOpcode::kWhile, "while", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, - ExprTree(HloOpcode::kConstant, "loop_start")))); - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while init"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); - - // Get tagged while instruction check against 'while_hlo_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo, - GetTaggedInstruction("while", tagged_instructions)); - if (while_hlo != while_hlo_) { - return InvalidArgument("Expected While for instruction : %s", - while_hlo->name().c_str()); - } - - // Get tagged Constant instruction and parse 'loop_start_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_start", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - - return Status::OK(); - } - - const HloInstruction* while_hlo_; - const int64 tuple_index_; - - int64 loop_start_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher); -}; - -// WhileBodyComputationMatcher matches a target computation pattern for -// the loop induction variable update. Matching proceeds from the while body -// computation root[tuple_index] to param[tuple_index], where 'tuple_index' -// If the target pattern is matched, parses a constant which represents the -// loop induction variable increment value, then returns status OK. -// Returns error status otherwise. -class WhileBodyComputationMatcher : public MatcherBase { - public: - WhileBodyComputationMatcher(const HloComputation* computation, - const int64 tuple_index) - : computation_(computation), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildBodyExprTree(0, 1)); - expr_trees_.emplace_back(BuildBodyExprTree(1, 0)); - } - - int64 loop_increment() const { return loop_increment_; } - - private: - // Builds expression tree for the following while body computation: - // - // - // FusionParam FusionParam - // \ / - // Const Param \ GTE1 - // \ / \ / - // Fusion -----------> Add - // | - // Copy - // | - // Tuple0 - // - ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) { - // Build ExprTree for fused instructions. - ExprTree gte1 = - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")); - ExprTree fused_root(HloOpcode::kAdd, const_index, - ExprTree(HloOpcode::kParameter), gte_index, gte1); - - // Build fusion instruction (and set fused root). - ExprTree fusion(HloOpcode::kFusion, 0, - ExprTree(HloOpcode::kConstant, "loop_increment"), 1, - ExprTree(HloOpcode::kParameter, "param0")); - fusion.SetFusedRoot(fused_root); - - // Build top-level computation. - ExprTree tuple0(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, fusion)); - return tuple0; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while body"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - for (const auto& pair : tagged_instructions) { - const auto& tag = pair.first; - const auto& inst = pair.second; - - if (tag == "gte" && inst->tuple_index() != tuple_index_) { - // Check that the matched GTE instruction is at the 'tuple_index' we - // matched in the while condition computation. - return InvalidArgument("Unexpected tuple index instruction : %s", - inst->name().c_str()); - } else if (tag == "loop_increment") { - // ParseHloString the constant which represents the loop induction - // variable increment value. - TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); - } else if (tag == "param0" && - inst != computation_->parameter_instruction(0)) { - // Check that the matched parameter == parameter 0 from 'computation_'. - return InvalidArgument("Unexpected Parameter0 instruction : %s", - inst->name().c_str()); - } else if (tag == "gte.fusion_param.param0") { - // Fusion parameter: lookup and compare with associated fusion operand. - CHECK_EQ(HloOpcode::kParameter, inst->opcode()); - CHECK(inst->IsFused()); - if (inst->parent()->FusionInstruction()->operand( - inst->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - inst->name().c_str()); - } - } - } - return Status::OK(); - } - - const HloComputation* computation_; - const int64 tuple_index_; - - int64 loop_increment_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher); -}; - -} // namespace - -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo) { - if (while_hlo->opcode() != HloOpcode::kWhile) { - return InvalidArgument("Expected While instruction."); - } - - WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition()); - TF_RETURN_IF_ERROR(cond_matcher.Run()); - - WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(init_matcher.Run()); - - WhileBodyComputationMatcher body_matcher(while_hlo->while_body(), - cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(body_matcher.Run()); - - // Check for valid For loop parameters. - if (init_matcher.loop_start() >= cond_matcher.loop_limit()) { - return InvalidArgument("Loop start must be less than loop limit."); - } - if (body_matcher.loop_increment() <= 0) { - return InvalidArgument("Loop increment must greater than zero."); - } - return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(), - body_matcher.loop_increment()); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h deleted file mode 100644 index fe3a954e1828ee4a323872eea81f64c7e780ad24..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { -namespace gpu { - -// Runs an analysis of the while loop instruction 'while_hlo' (and its -// associated sub-computations) to determine if it can be transformed into an -// equivalent "for" loop with the following "for" loop parameters: -// -// *) 'loop_start': loop induction variable starting value. -// *) 'loop_limit': loop induction variable limit value. -// *) 'loop_increment': loop induction variable per-iteration increment value. -// -// Returns an std::tuple = (loop_start, loop_limit, loop_increment) on success. -// The values in the returned tuple are values extracted from the 'while_hlo' -// operand (and its sub-computations) during analysis. -// Returns an error status on failure. -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index dbc8442ed2785a112b674632689256c01282156b..9a61f8ac5a62e38e687a93890eb33481a01d51c8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -27,9 +26,6 @@ limitations under the License. namespace xla { namespace { -using ::testing::Eq; -using ::testing::HasSubstr; - class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() @@ -110,16 +106,17 @@ class WhileTransformerTest : public HloTestBase { void RunFusionPasses() { // Run standard fusion passes. - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .status()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .status()); } void RunCopyInsertionPass() { - HloVerifier verifier; + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); @@ -141,10 +138,7 @@ class WhileTransformerTest : public HloTestBase { Shape condition_result_shape_; }; -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -153,18 +147,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -173,19 +162,14 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { - // Build computation with invalid loop limit. +TEST_F(WhileTransformerTest, ImpossibleLoopLimit) { + // Build computation with an impossible loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); @@ -193,17 +177,13 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop start must be less than loop limit.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(0, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { +TEST_F(WhileTransformerTest, InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -212,11 +192,9 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop increment must greater than zero.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_FALSE(result); } } // namespace diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index aa89567ee86e59e197045c0b51eed3b9aa59fef7..ef70b688778df5115e2b5fe572d253a6948d076f 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -22,9 +22,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/types.h" @@ -43,8 +43,7 @@ namespace { // Adds a computation to the given HLO module which adds a scalar constant to // its parameter and returns the result. HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { - auto builder = - HloComputation::Builder(tensorflow::strings::StrCat("add_", addend)); + auto builder = HloComputation::Builder(absl::StrCat("add_", addend)); auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( @@ -84,7 +83,7 @@ HloComputation* CallForwardingComputation(HloComputation* computation, // the module. std::unique_ptr MakeBigGraph() { HloModuleConfig config; - auto module = MakeUnique("BigGraph", config); + auto module = absl::make_unique("BigGraph", config); auto builder = HloComputation::Builder("TestBigGraphvizGraph"); @@ -113,8 +112,11 @@ std::unique_ptr MakeBigGraph() { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfig::DEFAULT); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + vshape, clamp, param_v0, dot_dnums, precision_config)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 4005fc0d114a3ec7a38dfb5edecdaeb1e8497ade..2bd04259c0e8193e6fde415df17a8232c701dec4 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" @@ -28,13 +29,13 @@ using tensorflow::gtl::FlatSet; /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { + if (schedule.empty()) { return 0; } - const HloModule* module = module_sequence.begin()->first->parent(); + const HloModule* module = schedule.module(); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -45,37 +46,37 @@ StatusOr HeapSimulator::MinimumMemoryForModule( // bound, by minimizing the liveness of sub-computations. TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, - module_sequence, *points_to_analysis, size_function)); + HeapSimulator::Run(absl::make_unique(), *module, + schedule, *points_to_analysis, size_function)); return result.heap_size; } /*static*/ StatusOr HeapSimulator::MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function, - HeapSimulator::Options(), memory_by_computation)); + HeapSimulator::Run(absl::make_unique(), + computation, sequence, points_to_analysis, + size_function, HeapSimulator::Options(), + memory_by_computation)); return result.heap_size; } /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { - HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); + HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); const HloComputation* entry_computation = module.entry_computation(); - const std::vector& instruction_sequence = - FindOrDie(module_sequence, entry_computation); + const HloInstructionSequence& instruction_sequence = + schedule.sequence(entry_computation); TF_RETURN_IF_ERROR(heap.RunComputation( *entry_computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -84,13 +85,13 @@ StatusOr HeapSimulator::Run( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, const tensorflow::gtl::FlatMap* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr, memory_by_computation); + /*schedule=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -100,7 +101,7 @@ StatusOr HeapSimulator::Run( // 'instruction_sequence'. Status HeapSimulator::RunComputation( const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential @@ -131,7 +132,8 @@ Status HeapSimulator::RunComputation( // set of instructions that need to be visited contains all users of all // aliases, that is, all users of all instructions that have the buffer // contained in their points-to set. - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction); const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); @@ -142,7 +144,7 @@ Status HeapSimulator::RunComputation( } } else { // A GetTupleElement doesn't need to keep all of its operand's buffers - // alive. It only needs the buffers that relate to the element its + // alive. It only needs the buffers that relate to the element it's // extracting, and the tuple it's extracting from, but not the buffers // for the other elements. for (const BufferValue* buffer : points_to.element({})) { @@ -164,7 +166,8 @@ Status HeapSimulator::RunComputation( std::vector dead_buffers_to_free; std::vector operand_buffers_to_free; - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); @@ -275,22 +278,22 @@ Status HeapSimulator::RunComputation( *memory_by_computation_); } - // If the whole module is sequential, we can save memory by running the - // heap-simulation for sub-computations inline. E.g. the buffers for the - // condition and body of a kWhile instruction are only live for the duration - // of the instruction itself. + // If all computations in the module have been scheduled, we can save memory + // by running the heap-simulation for sub-computations inline. E.g. the + // buffers for the condition and body of a kWhile instruction are only live + // for the duration of the instruction itself. // // The order that the sub-computations are simulated does not affect - // correctness; since the whole module is sequential, we know that the + // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. - if (module_sequence_ != nullptr) { + if (schedule_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { - const std::vector& called_sequence = - FindOrDie(*module_sequence_, called_computation); + const HloInstructionSequence& called_sequence = + schedule_->sequence(called_computation); TF_RETURN_IF_ERROR(RunComputation( *called_computation, called_sequence, points_to_analysis)); } @@ -341,16 +344,16 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence, + const HloSchedule* schedule, const tensorflow::gtl::FlatMap* memory_by_computation) - : no_fragmentation_stats_(MakeUnique()), + : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence), + schedule_(schedule), memory_by_computation_(memory_by_computation) { - debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); + debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } HeapSimulator::~HeapSimulator() {} @@ -378,9 +381,10 @@ void HeapSimulator::Alloc(const BufferValue* buffer, allocated_buffers_.insert(buffer); const int64 size = size_fn_(*buffer); - algorithm_->Alloc(buffer, size); - no_fragmentation_stats_->Alloc(buffer, size); - + const HloInstruction* instruction_to_calc_aliasing = + memory_by_computation_ == nullptr ? nullptr : instruction; + algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing); + no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing); FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction, nullptr); } @@ -518,6 +522,18 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } +void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. + if (instruction == nullptr || + (instruction->opcode() != HloOpcode::kWhile && + instruction->opcode() != HloOpcode::kCall && + instruction->opcode() != HloOpcode::kConditional)) { + Alloc(buffer, size); + } +} + void NoFragmentationStatsHeap::AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap& @@ -720,4 +736,209 @@ HeapSimulator::Result LazyBestFitHeap::Finish() { return result_; } +void GlobalDecreasingSizeBestFitHeap::Alloc(const BufferValue* buffer, + int64 size) { + // Degenerate case: 0-sized buffers are always allocated at offset 0. + if (size == 0) { + result_.chunk_map.emplace(buffer, Chunk{0, 0}); + return; + } + auto emplace_result = buffer_intervals_.emplace( + buffer, BufferInterval{buffer, size, current_time_, -1}); + DCHECK(emplace_result.second); + ++current_time_; +} + +void GlobalDecreasingSizeBestFitHeap::Free(const BufferValue* buffer, + int64 size) { + // Degenerate case: 0-sized buffers are always allocated at offset 0. + if (size == 0) { + return; + } + BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer); + DCHECK_EQ(buffer_interval.buffer, buffer); + DCHECK_EQ(buffer_interval.size, size); + DCHECK_EQ(buffer_interval.end, -1); + buffer_interval.end = current_time_; + ++current_time_; +} + +namespace { + +// Node in BufferIntervalTree that stores the alloc and free times of a buffer, +// and the chunk assigned to it. +struct BufferIntervalTreeNode { + // Alloc time. + int64 start; + // Free time. + int64 end; + // Maximum free time of all nodes in the subtree where this node is the root. + int64 subtree_end; + // Allocated chunk for the buffer. + HeapSimulator::Chunk chunk; + // Left child. + BufferIntervalTreeNode* left; + // Right child. + BufferIntervalTreeNode* right; +}; + +// An interval tree that can query buffers overlapping in time. +class BufferIntervalTree { + public: + explicit BufferIntervalTree(int capacity) : node_storage_(capacity) {} + + using Chunk = HeapSimulator::Chunk; + + // Adds a buffer to the interval tree, with the time interval and allocated + // chunk specified. + void Add(int64 start, int64 end, const Chunk& chunk) { + int index = node_count_; + DCHECK_LT(index, node_storage_.size()); + ++node_count_; + + node_storage_[index] = + BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr}; + + if (index == 0) { + // This is root. + return; + } + + BufferIntervalTreeNode* parent = &node_storage_[0]; + while (true) { + parent->subtree_end = std::max(parent->subtree_end, end); + if (parent->start > start) { + if (parent->left == nullptr) { + parent->left = &node_storage_[index]; + return; + } + parent = parent->left; + } else { + if (parent->right == nullptr) { + parent->right = &node_storage_[index]; + return; + } + parent = parent->right; + } + } + } + + // Returns vector of allocated chunks that overlap with the given time + // interval. + std::vector ChunksOverlappingInTime(int64 start, int64 end) { + std::vector result; + if (node_count_ == 0) { + return result; + } + std::vector visiting_stack; + visiting_stack.push_back(&node_storage_[0]); + while (!visiting_stack.empty()) { + BufferIntervalTreeNode* top = visiting_stack.back(); + visiting_stack.pop_back(); + if (start > top->subtree_end) { + continue; + } + if (top->left != nullptr) { + visiting_stack.push_back(top->left); + } + if (top->start <= end && top->end >= start) { + result.push_back(top->chunk); + } + if (end < top->start) { + continue; + } + if (top->right != nullptr) { + visiting_stack.push_back(top->right); + } + } + return result; + } + + private: + int64 node_count_ = 0; + std::vector node_storage_; +}; + +} // namespace + +HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { + std::vector sorted_buffer_intervals; + for (auto& entry : buffer_intervals_) { + sorted_buffer_intervals.push_back(entry.second); + } + std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(), + [](const BufferInterval& x, const BufferInterval& y) { + if (x.size != y.size) { + return x.size > y.size; + } + if (x.end - x.start != y.end - y.start) { + return x.end - x.start > y.end - y.start; + } + return x.buffer->id() < y.buffer->id(); + }); + + BufferIntervalTree interval_tree(sorted_buffer_intervals.size()); + for (auto& buffer_interval : sorted_buffer_intervals) { + auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime( + buffer_interval.start, buffer_interval.end); + std::sort( + chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(), + [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; }); + + // Find the minimum free chunk that can hold this buffer. + Chunk min_fit_chunk{-1, INT64_MAX}; + auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) { + if (free_size < buffer_interval.size) { + return; + } + + if (free_size < min_fit_chunk.size) { + min_fit_chunk = {free_offset, free_size}; + } + }; + + int64 offset = 0; + for (auto& chunk : chunks_overlapping_in_time) { + if (offset < chunk.offset) { + use_free_chunk_if_smaller(offset, chunk.offset - offset); + } + offset = + std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_)); + } + use_free_chunk_if_smaller(offset, result_.heap_size - offset); + + if (min_fit_chunk.offset == -1) { + // Increase the heap size to fit in the last free chunk. + result_.heap_size = offset + buffer_interval.size; + min_fit_chunk = {offset, buffer_interval.size}; + } + + min_fit_chunk.size = buffer_interval.size; + const auto emplace_result = + result_.chunk_map.emplace(buffer_interval.buffer, min_fit_chunk); + DCHECK(emplace_result.second); + + interval_tree.Add(buffer_interval.start, buffer_interval.end, + min_fit_chunk); + } + return result_; +} + +HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() { + DCHECK(!algorithms_.empty()); + std::vector results(algorithms_.size()); + int64 min_size = INT64_MAX; + int min_size_index = -1; + for (int i = 0; i < algorithms_.size(); ++i) { + results[i] = algorithms_[i]->Finish(); + if (results[i].heap_size < min_size) { + min_size = results[i].heap_size; + min_size_index = i; + } + } + + DCHECK_GE(min_size_index, 0); + return results[min_size_index]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 811a6042df9434ac3f4bed71b9c093433e25c1bb..7d6dcc0dc9436ea6bd30ae14ffe226c014f1ca68 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -36,6 +37,7 @@ namespace xla { // Forward declare classes defined below. class HeapAlgorithm; +class NoFragmentationStatsHeap; // HeapSimulator assigns buffer offsets by running a simulation of a regular // memory heap with Alloc and Free calls. It only works for completely @@ -87,23 +89,22 @@ class HeapSimulator { // Returns the minimum memory required to compute an HLO module where all // computations have been scheduled (represented by the given - // module_sequence), assuming no fragmentation. + // schedule), assuming no fragmentation. static StatusOr MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function); // Returns the minimum memory required to compute the given computation, // assuming no fragmentation. static StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given - // module_sequence, which must contain a topologically-consistent total + // schedule, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid // if instructions are not run in exactly this sequence. // @@ -111,12 +112,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - static StatusOr Run( - std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + static StatusOr Run(std::unique_ptr algorithm, + const HloModule& module, + const HloSchedule& schedule, + const TuplePointsToAnalysis& points_to_analysis, + const BufferValue::SizeFunction& size_fn, + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -125,7 +126,7 @@ class HeapSimulator { static StatusOr Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), @@ -133,21 +134,19 @@ class HeapSimulator { memory_by_computation = nullptr); private: - // If 'module_sequence' is non-null, it is used to find kCall and kWhile + // If 'schedule' is non-null, it is used to find kCall and kWhile // sub-computations, and the heap simulation for those sub-computations will // be run recursively. I.e. the simulation is run over the whole module. - HeapSimulator( - std::unique_ptr algorithm, - const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, - const tensorflow::gtl::FlatMap* - memory_by_computation = nullptr); + HeapSimulator(std::unique_ptr algorithm, + const BufferValue::SizeFunction& size_fn, + const Options& options, const HloSchedule* schedule = nullptr, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); ~HeapSimulator(); - Status RunComputation( - const HloComputation& computation, - const std::vector& instruction_sequence, - const TuplePointsToAnalysis& points_to_analysis); + Status RunComputation(const HloComputation& computation, + const HloInstructionSequence& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis); bool IgnoreBuffer(const BufferValue* buffer) const; void Alloc(const BufferValue* buffer, const HloInstruction* instruction); @@ -161,15 +160,18 @@ class HeapSimulator { const HloInstruction* instruction, const BufferValue* shared_with_canonical); - const std::unique_ptr no_fragmentation_stats_; + // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, + // in which case we are calculating the same allocs/frees twice in the + // simulation. + const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; - // module_sequence_ is set by buffer assignment, and memory_by_computation_ is + // schedule_ is set by buffer assignment, and memory_by_computation_ is // set by hlo scheduling. Then, in RunComputation, we check both in order to // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. - const SequentialHloOrdering::HloModuleSequence* module_sequence_; + const HloSchedule* schedule_; const tensorflow::gtl::FlatMap* memory_by_computation_; @@ -216,6 +218,21 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; + // NoFragmentationStatsHeap overrides this method. + virtual void Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) { + Alloc(buffer, size); + } + + // Takes memory usage of subcomputations into account when calculating the + // memory usage of a computation. Currently, we don't handle buffer aliasing + // between computations entirely correctly. We are careful to not double count + // for the output buffers of whiles/conds/calls. But we don't take into + // account other aliases, such as for the while init. A more thorough solution + // would require something like BufferAssignment::BuildColocatedBufferSets. + // TODO(b/65835246): + // Since TuplePointsToAnalysis is being replaced with a module-aware alias + // analysis, it's not worth making major changes to HeapSimulator now. virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap& @@ -240,6 +257,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferValue* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) override; + void AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap& @@ -331,6 +351,68 @@ class LazyBestFitHeap : public HeapAlgorithm { std::set free_; }; +// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers, +// then allocates them in decreasing sizes regardless of the alloc/free time. It +// internally tracks the allocated buffers and their live intervals; when +// allocating a buffer, it finds the best-fit free chunk during its live +// interval. +class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { + public: + GlobalDecreasingSizeBestFitHeap(int64 alignment) : alignment_(alignment) {} + ~GlobalDecreasingSizeBestFitHeap() override {} + + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; + Result Finish() override; + + private: + int64 alignment_; + Result result_; + + // The current time represented as an integer. It increments by 1 at each + // Alloc or Free call. + int64 current_time_ = 0; + + // BufferInterval stores a buffer's size and time interval. + struct BufferInterval { + const BufferValue* buffer; + int64 size; + // Alloc time of the buffer. + int64 start; + // Free time of the buffer. + int64 end; + }; + tensorflow::gtl::FlatMap + buffer_intervals_; +}; + +// A heap algorithm that chooses the best results from other algorithms added to +// it. +class ChooseBestHeapAlgorithm : public HeapAlgorithm { + public: + ChooseBestHeapAlgorithm( + std::unique_ptr>> algorithms) + : algorithms_(std::move(*algorithms)) {} + ~ChooseBestHeapAlgorithm() override {} + + void Alloc(const BufferValue* buffer, int64 size) override { + for (auto& algorithm : algorithms_) { + algorithm->Alloc(buffer, size); + } + } + + void Free(const BufferValue* buffer, int64 size) override { + for (auto& algorithm : algorithms_) { + algorithm->Free(buffer, size); + } + } + + Result Finish() override; + + private: + std::vector> algorithms_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_ diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index b41dc66fe9f5e869a114be96b7cc01fc1a3d59da..191fbf8194ac65684cd7bfd48a6931d82c702186 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -28,13 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; +class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { auto module = CreateNewModule(); @@ -84,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) - .ValueOrDie()); + HloSchedule schedule(module); + schedule.set_sequence(cond_computation, + {cond_param, cond_iter, cond_data, cond_lt}); + schedule.set_sequence(body_computation, {body_param}); + schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ( + 56, + HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); } const char kAlloc[] = "Alloc"; @@ -137,7 +142,7 @@ class HeapSimulatorTracker { const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { HloModuleConfig config; - module_ = MakeUnique(name, config); + module_ = absl::make_unique(name, config); module_->AddEntryComputation(std::move(computation)); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -146,17 +151,18 @@ class HeapSimulatorTracker { // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by // buffer id, for determinism in the tests. auto zero_size = [](const BufferValue& buffer) { return 0; }; - auto algorithm = MakeUnique( - MakeUnique(&actual_calls_)); - result_ = HeapSimulator::Run( - std::move(algorithm), *module_->entry_computation(), - instruction_sequence, *points_to_analysis_, zero_size) - .ConsumeValueOrDie(); + auto algorithm = absl::make_unique( + absl::make_unique(&actual_calls_)); + result_ = + HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(), + HloInstructionSequence(instruction_sequence), + *points_to_analysis_, zero_size) + .ConsumeValueOrDie(); } explicit HeapSimulatorTracker(const string& name) { HloModuleConfig config; - module_ = MakeUnique(name, config); + module_ = absl::make_unique(name, config); } // Similar to the single entry computation constructor above, but runs the @@ -167,11 +173,12 @@ class HeapSimulatorTracker { TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); // Construct the module sequence grouped by computation. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_.get()); tensorflow::gtl::FlatMap reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; - module_sequence[instruction->parent()].push_back(instruction); + schedule.GetOrCreateSequence(instruction->parent()) + .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; } @@ -182,10 +189,10 @@ class HeapSimulatorTracker { auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; }; - auto algorithm = MakeUnique( - MakeUnique(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), *module_, - module_sequence, *points_to_analysis_, size_fn) + auto algorithm = absl::make_unique( + absl::make_unique(&actual_calls_)); + result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule, + *points_to_analysis_, size_fn) .ConsumeValueOrDie(); } @@ -226,7 +233,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloTestBase { +class HeapSimulatorTest : public HloVerifiedTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} @@ -365,8 +372,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. @@ -401,8 +408,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -439,10 +446,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot1 is the output. No buffers can be shared. The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -480,10 +487,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); @@ -675,7 +682,8 @@ class HeapAlgorithmTestBase : public ::testing::Test { const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); + buffers_.emplace_back( + absl::make_unique(id, const0, ShapeIndex{})); return buffers_.back().get(); } @@ -724,7 +732,8 @@ class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {}; TEST_F(DecreasingSizeRunsHeapTest, Empty) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Finish(); EXPECT_EQ(call_sequence, CallSequence({ {kFinish, nullptr}, @@ -733,7 +742,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Empty) { TEST_F(DecreasingSizeRunsHeapTest, Simple) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 30); @@ -760,7 +770,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Simple) { TEST_F(DecreasingSizeRunsHeapTest, Mixed) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Free(buffer_b_, 20); @@ -1010,5 +1021,135 @@ TEST_F(LazyBestFitHeapTest, Alignment) { EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset); } +class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {}; + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) { + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(0, result.heap_size); + EXPECT_EQ(0, result.chunk_map.size()); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { + // space + // ^ + // | +---a---+ + // | +-------+ + // | +---c---+ + // | +-------+ + // | | b | + // | +-------+ + // | +-------+ + // | | | + // | | d | + // | +-------+ + // -----------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 30); + heap.Alloc(buffer_c_, 20); + heap.Alloc(buffer_d_, 40); + heap.Free(buffer_a_, 10); + heap.Free(buffer_b_, 30); + heap.Free(buffer_c_, 20); + heap.Free(buffer_d_, 40); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(100, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size); + + EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) { + // space + // ^ + // | +-------+ + // | +---b---+ + // | +-------+ + // | | | + // | | d | + // | +---a---+ +-------+ + // | + // | +-------+ + // | | | + // | | c | + // | | | + // | +-------+ + // ---------------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 20); + heap.Alloc(buffer_c_, 50); + heap.Free(buffer_a_, 10); + heap.Alloc(buffer_d_, 40); + heap.Free(buffer_b_, 20); + heap.Free(buffer_c_, 50); + heap.Free(buffer_d_, 40); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(120, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size); + + EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) { + // space + // ^ + // | +-------+ + // | +---b---+ + // | +-------+ + // | | d | + // | +--a--+ +-------+ + // | +-------+ + // | | | + // | | c | + // | +-------+ + // | +-------+ + // | | | + // | | e | + // | | | + // | +-------+ + // ---------------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 20); + heap.Alloc(buffer_c_, 40); + heap.Free(buffer_a_, 10); + heap.Alloc(buffer_d_, 30); + heap.Alloc(buffer_e_, 50); + heap.Free(buffer_b_, 20); + heap.Free(buffer_c_, 40); + heap.Free(buffer_d_, 30); + heap.Free(buffer_e_, 50); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(140, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size); + EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size); + + EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 0b93d97c11abd89c24c130849a8357806066fce7..b19ec126382d143b6ded401f2fad56f950d04bbd 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,6 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. +// Next ID: 53 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -45,6 +46,8 @@ message HloInstructionProto { reserved "control_predecessor_names"; reserved 6; reserved "called_computation_names"; + reserved 44; + reserved "replica_group_ids"; string name = 1; string opcode = 2; @@ -74,6 +77,11 @@ message HloInstructionProto { // Describes the dimension numbers used for a convolution. xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16; + // The number of feature groups. Used for a convolution. Must be a divisor of + // the input feature dimension and output feature dimension. If not specified, + // it will use a default value of 1. + int64 feature_group_count = 50; + // Describes the [begin, end) index range and stride for slices. message SliceDimensions { int64 start = 1; @@ -133,7 +141,7 @@ message HloInstructionProto { // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; - repeated int64 gather_window_bounds = 34; + repeated int64 gather_slice_sizes = 34; // Compute Host. string channel_name = 41; @@ -151,8 +159,8 @@ message HloInstructionProto { // Backend configuration for the instruction. Has backend-specific meaning. string backend_config = 43; - // Cross Replica Sum fields. - repeated int64 replica_group_ids = 44; + // Cross replica op fields. + repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; string cross_replica_sum_barrier = 46; @@ -162,6 +170,12 @@ message HloInstructionProto { bool is_host_transfer = 47; xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; + + // Precision configuration for the instruction. Has backend-specific meaning. + xla.PrecisionConfig precision_config = 51; + + // Collective permute field. + repeated SourceTarget source_target_pairs = 52; } // Serialization of HloComputation. @@ -185,6 +199,17 @@ message HloComputationProto { int64 root_id = 6; } +// Serialization of an HLO schedule. An HLO schedule contains a total order of +// instructions for each non-fusion computation in the module. +message HloScheduleProto { + message InstructionSequence { + repeated int64 instruction_ids = 1; + } + + // Map from computation id to sequence. + map sequences = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -200,16 +225,9 @@ message HloModuleProto { // The id of this module. int64 id = 5; -} -// Serialization of HloOrdering. -message HloOrderingProto { - // NOTE: currently only sequential orderings are serialized. - message SequentialComputation { - string computation_name = 1; - repeated string instruction_names = 2; - } - repeated SequentialComputation sequential_computations = 1; + // The schedule for this module. + HloScheduleProto schedule = 7; } // Serialization of LogicalBuffer. @@ -291,6 +309,13 @@ message HeapSimulatorTrace { bool whole_module_simulation = 2; } +// An abstraction representing a set of HLO module built to run concurrently +// across different devices. +message HloModuleGroupProto { + string name = 1; + repeated HloModuleProto hlo_modules = 2; +} + // Serialization of BufferAssignment. message BufferAssignmentProto { // Alias represents a source LogicalBuffer, and the buffer location that @@ -308,8 +333,10 @@ message BufferAssignmentProto { // Grouping message that contains all of the information above. message HloProto { + reserved 2; + reserved "hlo_ordering"; + HloModuleProto hlo_module = 1; - HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index e8a4b034b4396860bd5873f43003844ce92dea6c..0986da65cbd3d550ecfa01212364518aba651d86 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -28,15 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; // Data structure used to construct the alias analysis. Thrown away after alias // analysis is complete. This data structure keeps track of which sets of @@ -414,7 +412,7 @@ Status HloAliasAnalysis::Verify() const { } string HloAliasAnalysis::ToString() const { - string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); + string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); StrAppend(&out, " Buffers at each position:\n"); for (const HloComputation* computation : module_->computations()) { for (const HloInstruction* instruction : computation->instructions()) { @@ -457,7 +455,7 @@ StatusOr> HloAliasAnalysis::Run( VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); - auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); + auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module)); TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, /*bitcast_defines_value=*/false, @@ -537,10 +535,10 @@ bool HloAliasAnalysis::HasLiveRangeInterference( if (ordering.MayInterfere(*values[i - 1], *values[i], dataflow_analysis())) { VLOG(1) << "In buffer " << buffer.id() << " containing values:\n " - << Join(values, ", ", - [](string* out, const HloValue* value) { - StrAppend(out, value->ToShortString()); - }) + << absl::StrJoin(values, ", ", + [](string* out, const HloValue* value) { + StrAppend(out, value->ToShortString()); + }) << "\nValue " << values[i - 1]->ToShortString() << " may interfere with value " << values[i]->ToShortString(); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 1fea544730c27efdaa260f55ea81c163165f7ed5..e345804537723f01e9ccb63e7d6ded1bd68f4196 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index da94ab5346e5628b4a603b3ac2d84071904d1e65..0cd0ab36fcf832af9a71ab5837c94f9b39bc4bf3 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" @@ -39,15 +39,17 @@ namespace { using ::testing::UnorderedElementsAre; -class HloAliasAnalysisTest : public HloTestBase { +class HloAliasAnalysisTest : public HloVerifiedTestBase { protected: - HloAliasAnalysisTest() : module_(CreateNewModule()) {} + HloAliasAnalysisTest() : HloVerifiedTestBase() { + module_ = CreateNewModule(); + } // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. HloAliasAnalysis& RunAnalysis() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); - analysis_ = HloAliasAnalysis::Run(module_.get(), + analysis_ = HloAliasAnalysis::Run(module_, /*fusion_can_share_buffer=*/nullptr) .ConsumeValueOrDie(); return *analysis_; @@ -91,7 +93,7 @@ class HloAliasAnalysisTest : public HloTestBase { // never occurs, but HLO graphs with interference can be explicitly // constructed. bool AnyValuesInSameBufferInterfere() { - DependencyHloOrdering ordering(module_.get()); + DependencyHloOrdering ordering(module_); for (const HloBuffer& buffer : analysis_->buffers()) { for (const HloValue* value_a : buffer.values()) { for (const HloValue* value_b : buffer.values()) { @@ -108,7 +110,7 @@ class HloAliasAnalysisTest : public HloTestBase { return false; } - std::unique_ptr module_; + HloModule* module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -461,7 +463,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { module_->AddEntryComputation(builder.Build()); FlattenCallGraph flattener; - TF_ASSERT_OK(flattener.Run(module_.get()).status()); + TF_ASSERT_OK(flattener.Run(module_).status()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -835,7 +837,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { const HloAliasAnalysis& analysis = RunAnalysis(); - DependencyHloOrdering ordering(module_.get()); + DependencyHloOrdering ordering(module_); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } @@ -877,24 +879,26 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { { // Dependency ordering should interfere because the negate and while are // unordered. - DependencyHloOrdering ordering(module_.get()); + DependencyHloOrdering ordering(module_); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } // For a sequential order, if there is interference iff the negate is after // the while. - SequentialHloOrdering::HloModuleSequence sequence; - sequence[body] = {body_param, body_root}; - sequence[condition] = {cond_param, cond_root}; + HloSchedule schedule(module_); + schedule.set_sequence(body, {body_param, body_root}); + schedule.set_sequence(condition, {cond_param, cond_root}); { - sequence[entry] = {init, xla_while, negate, entry_root}; - SequentialHloOrdering ordering(module_.get(), sequence); + schedule.set_sequence(entry, {init, xla_while, negate, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } { - sequence[entry] = {init, negate, xla_while, entry_root}; - SequentialHloOrdering ordering(module_.get(), sequence); + schedule.set_sequence(entry, {init, negate, xla_while, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } } diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index e16413f361fb0216792b47c3c67ef3c1357c2221..6c11a073b74c61e44dfe81a32261ae78ae7b46fb 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -27,15 +29,10 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; - bool HloBuffer::operator==(const HloBuffer& other) const { bool equal = id() == other.id(); if (equal) { @@ -59,10 +56,11 @@ std::vector HloBuffer::ComputePositions() const { } string HloBuffer::ToString() const { - return StrCat("HloBuffer ", id_, ", values: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return absl::StrCat( + "HloBuffer ", id_, ", values: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h index 4873463b2ea4fee3ee39dff31fc3429a4998142f..a88c87e46c8100571aff24f70a2a19fe8ce71ebc 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.h +++ b/tensorflow/compiler/xla/service/hlo_buffer.h @@ -84,7 +84,7 @@ class HloBuffer { return a->id() == b->id(); } - HloBuffer(Id id, tensorflow::gtl::ArraySlice values) + HloBuffer(Id id, absl::Span values) : id_(id), values_(values.begin(), values.end()) {} // Return the unique identifier for this HloBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 441288da1a6859a3f393a298ee02eb4b435e42e0..0e5920af7a60966ace4ff52662cd23ea3141d477 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -23,9 +23,13 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -36,13 +40,11 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::StrCat; +using absl::StrCat; std::unique_ptr HloComputation::Builder::Build( HloInstruction* root_instruction) { @@ -56,8 +58,8 @@ std::unique_ptr HloComputation::Builder::Build( HloInstruction* root = root_instruction ? root_instruction : last_added_instruction_; CHECK_NE(nullptr, root); - return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, - root, fusion_instruction_)); + return absl::WrapUnique(new HloComputation( + name_, parameter_count, &instructions_, root, fusion_instruction_)); } HloComputation::HloComputation( @@ -135,7 +137,7 @@ string RenameFusionParameter(const string& original_name, int64 new_param_no) { } string after_param = original_name.substr(index + param_underscore.size()); int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + if (absl::SimpleAtoi(after_param, &numeric_suffix)) { return StrCat(original_name.substr(0, index + param_underscore.size()), new_param_no); } @@ -270,18 +272,19 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { << "instruction " << instruction->name() << " has control successors and cannot be removed"; - TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); - auto inst_it = instruction_iterators_.at(instruction); - (*inst_it)->set_parent(nullptr); - instructions_.erase(inst_it); + auto inst_it = instruction_iterators_.find(instruction); + TF_RET_CHECK(inst_it != instruction_iterators_.end()); + (*inst_it->second)->set_parent(nullptr); + instructions_.erase(inst_it->second); + instruction_iterators_.erase(inst_it); return Status::OK(); } -void HloComputation::set_root_instruction( - HloInstruction* new_root_instruction) { +void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, + bool accept_different_shape) { // The shape of the root (ignoring layout) is an invariant of the computation // for non-fusion cases. - if (!IsFusionComputation()) { + if (!IsFusionComputation() && !accept_different_shape) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) << new_root_instruction->shape() << " is incompatible with " @@ -317,11 +320,12 @@ void ComputeComputationPostOrder( } } -enum State { kVisiting, kVisited }; +} // namespace -void ComputeInstructionPostOrder( +void HloComputation::ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap* visited) { + tensorflow::gtl::FlatMap* visited) const { std::vector dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -354,16 +358,71 @@ void ComputeInstructionPostOrder( for (HloInstruction* op : current->control_predecessors()) { dfs_stack.emplace_back(op); } + + // Add inputs for send->recv_done dependencies and cross-replica-sum + // dependencies. + switch (current->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(current->channel_id()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = current->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } + } + } + break; + } + default: + break; + } } } -} // namespace +HloComputation::ChannelDependencyMap +HloComputation::ComputeChannelDependencies() const { + ChannelDependencyMap channel_dependency_map; + for (const auto& instruction : instructions_) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + channel_dependency_map[instruction->channel_id()].push_back( + instruction.get()); + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = instruction->all_reduce_id(); + if (all_reduce_id) { + auto& dependencies = channel_dependency_map[all_reduce_id.value()]; + absl::c_copy(instruction->operands(), + std::back_inserter(dependencies)); + absl::c_copy(instruction->control_predecessors(), + std::back_inserter(dependencies)); + } + break; + } + default: + break; + } + } + return channel_dependency_map; +} std::vector HloComputation::MakeInstructionPostOrder() const { + auto channel_dependency_map = ComputeChannelDependencies(); std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; - tensorflow::gtl::FlatMap visited; + tensorflow::gtl::FlatMap visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -371,7 +430,8 @@ std::vector HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); + ComputeInstructionPostOrder(channel_dependency_map, &post_order, + instruction.get(), &visited); } } post_order.insert(post_order.end(), trace_instructions.begin(), @@ -405,6 +465,14 @@ std::vector HloComputation::MakeEmbeddedComputationsList() } string HloComputation::ToString(const HloPrintOptions& options) const { + return ToString(options, MakeInstructionPostOrder()); +} + +string HloComputation::ToString( + const HloPrintOptions& options, + absl::Span instruction_order) const { + CHECK_EQ(instruction_order.size(), instruction_count()); + std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { s << " "; @@ -427,7 +495,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const { new_options.set_indent_amount(options.indent_amount() + 1) .set_is_in_nested_computation(true); CanonicalNameMap name_map; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (const HloInstruction* instruction : instruction_order) { + CHECK_EQ(this, instruction->parent()); + for (int i = 0; i < new_options.indent_amount(); i++) { s << " "; } @@ -493,13 +563,15 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + auto computation = absl::WrapUnique( + new HloComputation(proto.name(), parameter_count, &instructions, root, + /*fusion_instruction=*/nullptr)); + computation->unique_id_ = proto.id(); + return std::move(computation); } void HloComputation::FuseInstructionsInto( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction* fusion_instruction) { CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); HloInstruction* root = instructions_to_fuse.front(); @@ -518,7 +590,7 @@ void HloComputation::FuseInstructionsInto( } HloInstruction* HloComputation::CreateFusionInstruction( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction::FusionKind fusion_kind) { HloInstruction* root = instructions_to_fuse.front(); HloInstruction* fusion_instruction = AddInstruction( @@ -566,16 +638,15 @@ StatusOr HloComputation::DeepCopyInstruction( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } if (indices_to_copy != nullptr && !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { return FailedPrecondition( "Can't deep copy instruction %s: given shape tree of indices to copy " "has incompatible shapes: %s vs. %s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanString(indices_to_copy->shape()).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanString(indices_to_copy->shape())); } ShapeIndex index; @@ -605,7 +676,7 @@ StatusOr HloComputation::DeepCopyInstructionWithCustomCopier( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } ShapeIndex index; return DeepCopyHelper(instruction, &index, copy_leaf); @@ -624,6 +695,9 @@ ProgramShape HloComputation::ComputeProgramShape() const { } bool HloComputation::operator==(const HloComputation& other) const { + if (this == &other) { + return true; + } std::set> visited; std::function eq = [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { @@ -674,13 +748,37 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr HloComputation::ComputeReachability() const { const auto& all = MakeInstructionPostOrder(); - auto result = MakeUnique(all); + auto result = absl::make_unique(all); + auto channel_dependency_map = ComputeChannelDependencies(); std::vector inputs; for (const HloInstruction* hlo : all) { inputs.assign(hlo->operands().begin(), hlo->operands().end()); inputs.insert(inputs.end(), hlo->control_predecessors().begin(), hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + } + break; + } + default: + break; + } + result->FastSetReachabilityToUnion(inputs, hlo); } return result; @@ -723,11 +821,10 @@ std::vector HloComputation::CollectUnreachableRoots() const { } } VLOG(3) << "Unreachable roots:" - << tensorflow::str_util::Join( - unreachable_roots, "\n\t", - [](string* out, const HloInstruction* hlo) { - tensorflow::strings::StrAppend(out, hlo->ToString()); - }); + << absl::StrJoin(unreachable_roots, "\n\t", + [](string* out, const HloInstruction* hlo) { + absl::StrAppend(out, hlo->ToString()); + }); return unreachable_roots; } @@ -820,16 +917,17 @@ std::unique_ptr HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - context, suffix); + /*extras=*/{}, context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloCloneContext* context, const string& suffix) { + absl::Span extras, HloCloneContext* context, + const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { - context_ptr = MakeUnique(parent(), suffix); + context_ptr = absl::make_unique(parent(), suffix); context = context_ptr.get(); } @@ -848,6 +946,9 @@ std::unique_ptr HloComputation::CloneWithReplacements( VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; std::vector postorder; + for (HloInstruction* instr : extras) { + postorder.push_back(instr); + } for (HloInstruction* instr : MakeInstructionPostOrder()) { if (HloInstruction* replacement = replace(instr)) { postorder.push_back(replacement); @@ -898,12 +999,11 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } -HloInstruction* HloComputation::GetInstructionWithName( - tensorflow::StringPiece name) { +HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) { auto instructions_in_computation = instructions(); - auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) { - return instr->name() == name; - }); + auto it = absl::c_find_if( + instructions_in_computation, + [&](HloInstruction* instr) { return instr->name() == name; }); return it == instructions_in_computation.end() ? nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 49ed65910f519810740b89760ad815f287e59a91..936a53bd7e9ad362d10f06ab807ddb8944fec93e 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" @@ -134,9 +134,11 @@ class HloComputation { Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction); // Set the root of the computation to the given instruction. The instruction - // must have already been added to the computation and have the same shape as - // the result of the computation for non fusion computations. - void set_root_instruction(HloInstruction* new_root_instruction); + // must have already been added to the computation. In addition it must have + // the same shape as the result of the computation for non fusion + // computations, except if accept_different_shape is set to true. + void set_root_instruction(HloInstruction* new_root_instruction, + bool accept_different_shape = false); // Return the root instruction of the computation. The root instruction is the // instruction which produces the output of the computation. @@ -170,6 +172,11 @@ class HloComputation { string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; + // Overload which accepts an order to emit the instructions in. + string ToString( + const HloPrintOptions& options, + absl::Span instruction_order) const; + // Returns a serialized representation of this computation. HloComputationProto ToProto() const; @@ -220,7 +227,7 @@ class HloComputation { void UpdateReachabilityThroughInstruction( const HloInstruction* instruction, HloReachabilityMap* reachability_map); - int64 instruction_count() const { return instructions_.size(); } + int64 instruction_count() const { return instruction_iterators_.size(); } // Creates and returns a list of the embedded computations called by this // computation. This includes all embedded computations called directly or @@ -237,7 +244,7 @@ class HloComputation { // removed if they have no uses after fusion (this is necessarily true for at // least the root). HloInstruction* CreateFusionInstruction( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction::FusionKind fusion_kind); // Create a deep copy of the given instruction and return the instruction @@ -326,10 +333,13 @@ class HloComputation { // // If replacements maps a key to nullptr, we remove that instruction from the // new computation. + // If additional instructions are used by instructions in replacement map, + // they must be passed in post-order in the extras span. std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloCloneContext* context = nullptr, const string& suffix = "clone"); + absl::Span extras, HloCloneContext* context = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of @@ -367,7 +377,7 @@ class HloComputation { // Returns the instruction in this computation that has name `name`. Returns // null if there is no such computation. - HloInstruction* GetInstructionWithName(tensorflow::StringPiece name); + HloInstruction* GetInstructionWithName(absl::string_view name); int64 unique_id() const { return unique_id_; } @@ -385,7 +395,7 @@ class HloComputation { // // Pre-condition: fusion_instruction's opcode is kFusion. void FuseInstructionsInto( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction* fusion_instruction); // Internal helper for recursive copying of an instruction. Creates and @@ -399,6 +409,20 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; + // Returns a map from channel-id to directed dependencies of the channel + // instructions. For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. + using ChannelDependencyMap = + tensorflow::gtl::FlatMap>; + ChannelDependencyMap ComputeChannelDependencies() const; + + enum VisitState { kVisiting, kVisited }; + void ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, + std::vector* post_order, HloInstruction* root, + tensorflow::gtl::FlatMap* visited) const; + string name_; int64 unique_id_; HloInstruction* root_instruction_; @@ -415,7 +439,7 @@ class HloComputation { // instruction pointer to location in the list for fast lookup. using InstructionList = std::list>; InstructionList instructions_; - std::unordered_map + tensorflow::gtl::FlatMap instruction_iterators_; std::vector param_instructions_; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index e4c547033139185d5dd4ef37db2d22a6431c1102..2aaaef1d36d58bcce18db4aa37ff05ea352e484b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -601,8 +601,11 @@ TEST_F(HloComputationTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -633,8 +636,11 @@ TEST_F(HloComputationTest, StringificationIndent) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -666,8 +672,11 @@ TEST_F(HloComputationTest, StringificationCanonical) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -691,6 +700,27 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -} // namespace +TEST_F(HloComputationTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = computation->ComputeReachability(); + EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + +} // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 7229031c0c7f8bd374cfb495c7d8c11e9ca8b95e..f837816cea78d78bb3d605dd91e81cac39036268 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -38,7 +39,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { // Limit the constant folding to 0 iterations to skip folding loops. This // retains the behavior from before while loop support in HloEvaluator and may // be revised. - auto evaluator = MakeUnique(/*max_loop_iterations=*/0); + auto evaluator = absl::make_unique(/*max_loop_iterations=*/0); XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); @@ -51,9 +52,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Reduce, and AfterAll operation. - // TODO(b/35975797): Enable Reduce operation once arbitrary computation - // are supported by the evaluator. + // Skip Constant, Parameter, and AfterAll operation. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one // operand in which case constant folding will be impossible and this @@ -61,7 +60,6 @@ StatusOr HloConstantFolding::Run(HloModule* module) { if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kReduce || instruction->opcode() == HloOpcode::kAfterAll) { continue; } @@ -73,14 +71,15 @@ StatusOr HloConstantFolding::Run(HloModule* module) { // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. - if (instruction->opcode() == HloOpcode::kBroadcast) { + if (instruction->opcode() == HloOpcode::kBroadcast || + instruction->opcode() == HloOpcode::kIota) { continue; } - std::unique_ptr result = evaluator->TryEvaluate(instruction); + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. - if (result == nullptr) { + if (!evaluator->TryEvaluate(instruction, &result)) { VLOG(2) << "Constant folding failed for instruction: " << instruction->ToString(); continue; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 331480bd029727fa15476cb9ced2e7b7afd170f3..4a624cc7b8483aaa834634185a23195e437bd4e4 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -23,9 +23,9 @@ namespace xla { // A pass which performs constant folding in order to avoid unnecessary // computation on constants. -class HloConstantFolding : public HloPassInterface { +class HloConstantFolding : public HloModulePass { public: - tensorflow::StringPiece name() const override { return "constant_folding"; } + absl::string_view name() const override { return "constant_folding"; } // Run constant folding operations on the given module. Returns whether the // module was changed (constant expressions folded). diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 64a42c1efc0c788ae8e66fb72b2d9aecec179082..3e0def5d26a0033d954a776c1c32d6c35acfb505 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -24,10 +24,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -36,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloTestBase; +using HloConstantFoldingTest = HloVerifiedTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -51,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -72,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -93,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -104,8 +105,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { TEST_F(HloConstantFoldingTest, Concatenate) { const struct TestConfig { int concat_dimension; - tensorflow::gtl::ArraySlice dimensions; - tensorflow::gtl::ArraySlice concat_sizes; + absl::Span dimensions; + absl::Span concat_sizes; } test_configs[] = { {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, @@ -133,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -160,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -174,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->Literal::CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ -185,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -195,12 +196,51 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; bool matched = true; root->literal().EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + [&](absl::Span indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - matched = matched && (value == literal_clone->Get(rindexes)); + matched = matched && (value == literal_clone.Get(rindexes)); }); EXPECT_TRUE(matched); } +const char* const kConstantFoldReduce = R"( + HloModule ConstantFoldReduce + + add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = s32[] add(a, b) + } + + ENTRY r { + x = s32[3] constant({1, 2, 3}) + init = s32[] constant(0) + ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add + })"; + +TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { + ParseAndVerifyModule(kConstantFoldReduce); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + EXPECT_TRUE(result); + + EXPECT_EQ(6, module() + .entry_computation() + ->root_instruction() + ->literal() + .GetFirstElement()); +} + +TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { + ParseAndVerifyModule(kConstantFoldReduce); + HloInstruction* add = module().computations().begin()->root_instruction(); + LayoutUtil::ClearLayout(add->mutable_shape()); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + EXPECT_FALSE(result); + + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index a2cefd26211eb9f09e8668a7fad9f8085ab0cd6a..a502fff9a0f1e40065746f2193bf76b1adefdb31 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { + // Domain does not have any computation or data transfer. + current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; + return Status::OK(); +} + Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); @@ -258,10 +266,6 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) { - return Status::OK(); -} - Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, @@ -278,15 +282,21 @@ Status HloCostAnalysis::HandleMap(const HloInstruction* map) { } Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { - auto arg = reduce->operand(0); HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, ProcessSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. - int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) - - ShapeUtil::ElementsIn(reduce->shape()); + // This counts the number of times the reduction function is applied, so it + // does not need to be multiplied by the number of input tensors - that's + // already "priced in" by the sub-computation doing more work. + auto arg = reduce->operand(0); + auto output_shape = ShapeUtil::IsArray(reduce->shape()) + ? reduce->shape() + : reduce->shape().tuple_shapes(0); + int64 reduction_count = + ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape); for (const auto& property : sub_properties) { if (property.first != kBytesAccessedKey) { current_properties_[property.first] = property.second * reduction_count; @@ -505,8 +515,9 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { valid_position_counts.push_back(valid_position_count); } - const int64 fma_count = - input_feature * output_feature * batch * Product(valid_position_counts); + const int64 fma_count = (input_feature / convolution->feature_group_count()) * + output_feature * batch * + Product(valid_position_counts); current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); } @@ -543,6 +554,14 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { return Status::OK(); } +Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandleRng(const HloInstruction* random) { // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 0a79c92f4a95f6337c8c25b47f6967fc9ff3fd98..46b4bbeef222e6de581360fc01b293e812f1dedd 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -67,13 +67,15 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleRecvDone(const HloInstruction* recv_done) override; Status HandleConvert(const HloInstruction* convert) override; Status HandleCopy(const HloInstruction* copy) override; + Status HandleDomain(const HloInstruction* domain) override; Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleAllToAll(const HloInstruction* hlo) override; + Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; - Status HandleHostCompute(const HloInstruction* host_compute) override; Status HandleRng(const HloInstruction* random) override; Status HandleReverse(const HloInstruction* reverse) override; Status HandleSort(const HloInstruction* sort) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 2c854eea18642eb7cb081b4fdfe3bc83627e41ae..d76ce9ecbca67ae3bc3db4ee2452f30ccec5b88b 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) { sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18)); } +TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) { + XlaBuilder builder("convolution"); + auto input = Parameter( + &builder, 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = Parameter( + &builder, 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Output shape is [1x120x8x18] and each output element requires (3x3) + // FMAs and one FMA is 2 flops. + EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18)); +} + TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = @@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - XlaBuilder builder("matmul"); + XlaBuilder builder("tuple"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); Tuple(&builder, {x, y}); @@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +using DomainCostAnalysis = HloTestBase; +TEST_F(DomainCostAnalysis, DomainCost) { + HloCostAnalysis analysis(ShapeSize); + + HloComputation::Builder builder("domain"); + auto x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {123}), "x")); + auto y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y")); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y})); + auto domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); + ASSERT_IS_OK(domain->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(*domain), 0); + EXPECT_EQ(analysis.transcendental_count(*domain), 0); + EXPECT_EQ(analysis.bytes_accessed(*domain), 0); +} + TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); auto input = Parameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 90d2be118d94d52135820e5b8138fcb06389c684..b2005d3c210d4ae7e3702cb9624c3ad98056984c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -14,15 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::ArraySlice; -using tensorflow::strings::StrCat; +using absl::StrCat; StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { @@ -48,9 +50,9 @@ StatusOr MakePadHlo(HloInstruction* operand, } StatusOr MakeSliceHlo(HloInstruction* operand, - ArraySlice start_indices, - ArraySlice limit_indices, - ArraySlice strides) { + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( operand->shape(), start_indices, @@ -60,19 +62,22 @@ StatusOr MakeSliceHlo(HloInstruction* operand, } StatusOr MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers) { + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), - window, dimension_numbers)); + TF_ASSIGN_OR_RETURN(Shape convolve_shape, + ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), feature_group_count, + window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, window, dimension_numbers)); + convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config)); } StatusOr MakeTransposeHlo(HloInstruction* operand, - ArraySlice dimensions) { + absl::Span dimensions) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN( Shape transpose_shape, @@ -89,15 +94,15 @@ StatusOr MakeReshapeHlo(const Shape& result_shape, } StatusOr MakeReshapeHlo( - ArraySlice result_shape_dim_bounds, HloInstruction* operand) { + absl::Span result_shape_dim_bounds, HloInstruction* operand) { Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_dim_bounds); return MakeReshapeHlo(new_shape, operand); } -StatusOr MakeDynamicSliceHlo(HloInstruction* operand, - HloInstruction* start_indices, - ArraySlice slice_sizes) { +StatusOr MakeDynamicSliceHlo( + HloInstruction* operand, HloInstruction* start_indices, + absl::Span slice_sizes) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, start_indices->parent()); TF_ASSIGN_OR_RETURN( @@ -123,8 +128,8 @@ StatusOr MakeDynamicUpdateSliceHlo( } StatusOr MakeBroadcastHlo( - HloInstruction* operand, ArraySlice broadcast_dimensions, - ArraySlice result_shape_bounds) { + HloInstruction* operand, absl::Span broadcast_dimensions, + absl::Span result_shape_bounds) { HloComputation* computation = operand->parent(); Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_bounds); @@ -144,18 +149,18 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, HloInstruction::CreateGetTupleElement(gte_shape, operand, index)); } -StatusOr MakeConcatHlo(ArraySlice operands, - int64 dimension) { +StatusOr MakeConcatHlo( + absl::Span operands, int64 dimension) { CHECK_GT(operands.size(), 0); HloComputation* computation = operands[0]->parent(); - CHECK(c_all_of(operands, [&](HloInstruction* instr) { + CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) { return instr->parent() == computation; })); std::vector operand_shapes; - c_transform(operands, std::back_inserter(operand_shapes), - [](HloInstruction* instr) { return &instr->shape(); }); + absl::c_transform(operands, std::back_inserter(operand_shapes), + [](HloInstruction* instr) { return &instr->shape(); }); TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( operand_shapes, dimension)); @@ -164,14 +169,75 @@ StatusOr MakeConcatHlo(ArraySlice operands, } StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers) { + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN( Shape dot_shape, ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); + return computation->AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dim_numbers, precision_config)); +} + +StatusOr MakeMapHlo(absl::Span operands, + HloComputation* map_computation) { + CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; + HloComputation* computation = operands.front()->parent(); + std::vector operand_shapes; + int64 max_operand_rank = 0; + for (const HloInstruction* operand : operands) { + CHECK_EQ(computation, operand->parent()); + operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + } + std::vector map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); + TF_ASSIGN_OR_RETURN( + Shape map_shape, + ShapeInference::InferMapShape( + operand_shapes, map_computation->ComputeProgramShape(), map_dims)); return computation->AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); + HloInstruction::CreateMap(map_shape, operands, map_computation)); +} + +StatusOr MakeReduceHlo(HloInstruction* operand, + HloInstruction* init_value, + HloOpcode binary_opcode, + HloModule* module) { + DCHECK_NE(nullptr, module); + std::vector all_dims(ShapeUtil::Rank(operand->shape())); + std::iota(all_dims.begin(), all_dims.end(), 0); + + auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); + HloComputation* reduce_computation; + { + HloComputation::Builder b(operand->name() + ".reduce_sub_computation"); + auto lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + b.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs)); + reduce_computation = module->AddEmbeddedComputation(b.Build()); + } + + return operand->parent()->AddInstruction(HloInstruction::CreateReduce( + scalar_shape, operand, init_value, all_dims, reduce_computation)); +} + +StatusOr MakeSelectHlo(HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) { + HloComputation* computation = pred->parent(); + DCHECK_EQ(computation, on_true->parent()); + DCHECK_EQ(computation, on_false->parent()); + TF_ASSIGN_OR_RETURN(Shape select_shape, + ShapeInference::InferTernaryOpShape( + HloOpcode::kSelect, pred, on_true, on_false)); + return computation->AddInstruction(HloInstruction::CreateTernary( + select_shape, HloOpcode::kSelect, pred, on_true, on_false)); } StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { @@ -205,19 +271,19 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, const Shape& operand_shape = operand->shape(); new_shape_dims.reserve(n + operand_shape.dimensions_size()); new_shape_dims.insert(new_shape_dims.begin(), n, 1); - c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); + absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); return MakeReshapeHlo(new_shape_dims, operand); } StatusOr ExpandFirstDimIntoNDims( - HloInstruction* operand, ArraySlice expanded_dims) { + HloInstruction* operand, absl::Span expanded_dims) { CHECK_GT(operand->shape().dimensions_size(), 0); CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims)); std::vector expanded_shape_dim_bounds; expanded_shape_dim_bounds.reserve(expanded_dims.size() + operand->shape().dimensions_size() - 1); - c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); + absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); std::copy(operand->shape().dimensions().begin() + 1, operand->shape().dimensions().end(), std::back_inserter(expanded_shape_dim_bounds)); @@ -226,9 +292,9 @@ StatusOr ExpandFirstDimIntoNDims( return MakeReshapeHlo(new_shape, operand); } -StatusOr ElideDegenerateDims(HloInstruction* operand, - ArraySlice dims_to_elide) { - CHECK(c_is_sorted(dims_to_elide)); +StatusOr ElideDegenerateDims( + HloInstruction* operand, absl::Span dims_to_elide) { + CHECK(absl::c_is_sorted(dims_to_elide)); const Shape& input_shape = operand->shape(); // First accumulate in reverse @@ -245,12 +311,44 @@ StatusOr ElideDegenerateDims(HloInstruction* operand, } } - c_reverse(new_shape_dim_bounds); + absl::c_reverse(new_shape_dim_bounds); Shape output_shape = ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); return MakeReshapeHlo(output_shape, operand); } +StatusOr InsertDegenerateDims( + HloInstruction* operand, absl::Span dims_to_insert) { + CHECK(absl::c_is_sorted(dims_to_insert)); + + const Shape& operand_shape = operand->shape(); + int64 output_shape_rank = + operand_shape.dimensions_size() + dims_to_insert.size(); + for (auto dim_to_insert : dims_to_insert) { + CHECK_LT(dim_to_insert, output_shape_rank); + } + + std::vector output_shape_dim_bounds; + output_shape_dim_bounds.reserve(output_shape_rank); + int64 operand_dims_idx = 0; + int64 dims_to_insert_idx = 0; + for (int64 i = 0; i < output_shape_rank; ++i) { + if (dims_to_insert_idx < dims_to_insert.size() && + i == dims_to_insert[dims_to_insert_idx]) { + output_shape_dim_bounds.push_back(1); + ++dims_to_insert_idx; + } else { + output_shape_dim_bounds.push_back( + operand_shape.dimensions(operand_dims_idx)); + ++operand_dims_idx; + } + } + + Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), + output_shape_dim_bounds); + return MakeReshapeHlo(output_shape, operand); +} + StatusOr PadVectorWithZeros(HloInstruction* operand, int64 zeros_to_prepend, int64 zeros_to_append) { @@ -262,26 +360,25 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, padding_config_dim.set_edge_padding_high(zeros_to_append); *padding_config.add_dimensions() = padding_config_dim; - HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( - LiteralUtil::Zero(operand->shape().element_type())))); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(operand->shape().element_type()))); return MakePadHlo(operand, zero, padding_config); } StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, - ArraySlice broadcast_dimensions) { - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::Span broadcast_dimensions) { + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } StatusOr> CreateComputationWithSignature( - ArraySlice domain, const Shape& range, - tensorflow::StringPiece name) { - HloComputation::Builder b{std::string(name)}; + absl::Span domain, const Shape& range, + absl::string_view name) { + HloComputation::Builder b{string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { b.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 49b1402d689a74874e34423a1832a0b6aa15f469..8e5ddbbd503a501bd493aec43a2ccd4db883ef0c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/statusor.h" @@ -40,21 +41,22 @@ StatusOr MakePadHlo(HloInstruction* operand, // Creates a slice HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeSliceHlo( - HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); +StatusOr MakeSliceHlo(HloInstruction* operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). StatusOr MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeTransposeHlo( - HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); +StatusOr MakeTransposeHlo(HloInstruction* operand, + absl::Span dimensions); // Creates a reshape HLO instruction and adds it to the computation containing // `operand`. @@ -62,15 +64,14 @@ StatusOr MakeReshapeHlo(const Shape& result_shape, HloInstruction* operand); StatusOr MakeReshapeHlo( - tensorflow::gtl::ArraySlice result_shape_dim_bounds, - HloInstruction* operand); + absl::Span result_shape_dim_bounds, HloInstruction* operand); // Creates a dynamic-slice HLO instruction and adds it to the computation // containing `operand` and `start_indices` (`operand` and `start_indices` must // be in the same computation). StatusOr MakeDynamicSliceHlo( HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Creates a dynamic-update-slice HLO instruction and adds it to the computation // containing `operand`, `update` and `start_indices` (`operand`, `update` and @@ -82,9 +83,8 @@ StatusOr MakeDynamicUpdateSliceHlo( // Creates a broadcast HLO instruction and adds it to the computation containing // `operand`. StatusOr MakeBroadcastHlo( - HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions, - tensorflow::gtl::ArraySlice result_shape_bounds); + HloInstruction* operand, absl::Span broadcast_dimensions, + absl::Span result_shape_bounds); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. @@ -95,12 +95,47 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, // containing `operands` (`operands` must be non-empty and every element must be // contained in the same computation). StatusOr MakeConcatHlo( - tensorflow::gtl::ArraySlice operands, int64 dimension); + absl::Span operands, int64 dimension); // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers); + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config); + +// Creates a Map HLO instruction and adds it to the computation containing the +// operands. All operands must be in the same computation. +StatusOr MakeMapHlo(absl::Span operands, + HloComputation* map_computation); + +// Creates a Reduce HLO instruction and adds it to the computation containing +// the operand. This will create the sub-computation needed for the reduction in +// the given module. binary_opcode should represent a binary operation. +StatusOr MakeReduceHlo(HloInstruction* operand, + HloInstruction* init_value, + HloOpcode binary_opcode, + HloModule* module); + +// Creates a Select HLO instruction and adds it to the computation containing +// the predicate. The on_true and on_false instructions must also be contained +// in the same computation. +StatusOr MakeSelectHlo(HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false); + +// Creates an R1 Constant HLO instruction of the given PrimitiveType with the +// given values and adds it to the given computation. +template +StatusOr MakeR1ConstantHlo(HloComputation* computation, + PrimitiveType type, + absl::Span values) { + Literal literal = LiteralUtil::CreateR1(values); + if (literal.shape().element_type() != type) { + TF_ASSIGN_OR_RETURN(literal, literal.Convert(type)); + } + return computation->AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); +} // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of @@ -132,7 +167,7 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, // For instance if `operand` has shape f32[200,9,7] and expanded_dims is // {2,5,20} the result is `operand` reshaped to [2,5,20,9,7]. StatusOr ExpandFirstDimIntoNDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); + HloInstruction* operand, absl::Span expanded_dims); // Elides (via reshape) a set of degenerate dimensions (dimensions containing // exactly one element), `dims_to_elide` from `operand`. Every dimension in @@ -142,7 +177,17 @@ StatusOr ExpandFirstDimIntoNDims( // For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide // is {1,5} then the result is `operand` reshaped to [19,20,1,7,9]. StatusOr ElideDegenerateDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); + HloInstruction* operand, absl::Span dims_to_elide); + +// Inserts (via reshape) a set of degenerate dimensions (dimensions containing +// exactly one element), `dims_to_insert` into `operand`. The dimensions in +// `dims_to_insert` refer to the dimensions in the result, and hence should be +// less than the rank of the result. Also, `dims_to_insert` must be sorted. +// +// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is +// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34]. +StatusOr InsertDegenerateDims( + HloInstruction* operand, absl::Span dims_to_insert); // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the // front and `zeros_to_append` zeros in the back. @@ -155,13 +200,13 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, // broadcast instruction is emitted into `computation`. StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. StatusOr> CreateComputationWithSignature( - tensorflow::gtl::ArraySlice domain, const Shape& range, - tensorflow::StringPiece name); + absl::Span domain, const Shape& range, + absl::string_view name); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 60d3e71757d5ce31e025c744e089ff56091d9a43..e07a196d1154dc0ea45ccd2f15b0b9b56f7c41f8 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -14,23 +14,22 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { -using tensorflow::gtl::ArraySlice; -class HloCreationUtilsTest : public HloTestBase { +class HloCreationUtilsTest : public HloVerifiedTestBase { protected: - static std::unique_ptr CreateModuleWithProgramShape( - PrimitiveType primitive_type, ArraySlice input_shape_dims, - ArraySlice output_shape_dims, HloInstruction** param, + HloModule* CreateModuleWithProgramShape( + PrimitiveType primitive_type, absl::Span input_shape_dims, + absl::Span output_shape_dims, HloInstruction** param, HloComputation** entry_computation) { Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); Shape output_shape = @@ -48,27 +47,27 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{2}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed, CollapseFirstNDims(param, 1)); entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({3, 4})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1({3, 4})); } TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( + HloModule* module = CreateModuleWithProgramShape( S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, &entry_computation); @@ -79,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2( + CHECK_EQ(result_literal, + LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } @@ -93,10 +92,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 2}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended, PrependDegenerateDims(param, 1)); @@ -104,17 +103,17 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9, 10}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9, 10}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( + HloModule* module = CreateModuleWithProgramShape( S32, /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, ¶m, &entry_computation); @@ -125,37 +124,37 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR3({{{9, 10}}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{9, 10}}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{}, /*output_shape_dims=*/{1, 1}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{}, + /*output_shape_dims=*/{1, 1}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(9)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9}})); } TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( + HloModule* module = CreateModuleWithProgramShape( S32, /*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, ¶m, &entry_computation); @@ -166,21 +165,21 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{6}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{6}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zero_padded_param, @@ -188,20 +187,20 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, @@ -209,20 +208,20 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(0)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{0, 0}, {0, 0}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - F32, - /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(F32, + /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, @@ -230,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR0(0.0f)})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 06484f4012fc091f70df7bc8ec231ce3fcf89669..b59c9ba3ed7990eb2a35abc83f87b25a1b1e7c60 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -34,7 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -103,6 +104,9 @@ int64 CseHash(const HloInstruction* instruction) { for (auto operand : instruction->operands()) { hash = tensorflow::Hash64Combine(hash, operand->unique_id()); } + if (instruction->opcode() == HloOpcode::kConstant) { + hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash()); + } return hash; } diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index 5e2b348bdda2b31556fb692e24d2bad2e4173ef5..e4857fd3fdd9a329b013ac8215cb6d36d73c4b7d 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -25,7 +25,7 @@ namespace xla { // and identical instructions with the same operands are commoned. The pass // iterates over the instructions in topological order which enables the pass to // find arbitrarily large common expressions. -class HloCSE : public HloPassInterface { +class HloCSE : public HloModulePass { public: // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. @@ -34,7 +34,7 @@ class HloCSE : public HloPassInterface { : is_layout_sensitive_(is_layout_sensitive), only_fusion_computations_(only_fusion_computations) {} ~HloCSE() override = default; - tensorflow::StringPiece name() const override { return "cse"; } + absl::string_view name() const override { return "cse"; } // Run CSE on the given module. Returns whether the module was changed (common // subexpressions were found and eliminated). diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 90fbaa37c5a70a78a9a818b4a8968f3406c671b1..9b18b0284f63c25934c1b7118dc8973caa62cadc 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -20,16 +20,16 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloTestBase { +class HloCseTest : public HloVerifiedTestBase { protected: HloCseTest() {} }; @@ -65,15 +65,15 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(3, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = *computation->instructions().begin(); EXPECT_EQ(42.0f, constant->literal().Get({})); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR0(84.0); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -96,16 +96,16 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); auto first_operand = add->operand(0); EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2)); EXPECT_THAT(add, op::Add(first_operand, first_operand)); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -128,14 +128,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { @@ -177,7 +177,7 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); // CSE will remove both the second float(42.0f) and the corresponding // convert/cast. @@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) { op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -240,7 +240,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test two identical while loops with same inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -278,21 +278,20 @@ f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); } // Test two while loops with same conditions, same inputs, but different // bodies TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -329,20 +328,19 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body2 } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } // Test two identical while loops with different inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -373,21 +371,20 @@ f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(8, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(8, computation->instruction_count()); } // Test two while loops with identical bodies and same inputs, but different // conditions TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -414,14 +411,13 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - })") - .ValueOrDie(); + })"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } @@ -450,7 +446,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); @@ -481,7 +477,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -516,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) { EXPECT_EQ(5, fused_computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, fused_computation->instruction_count()); auto root = fused_computation->root_instruction(); @@ -565,7 +561,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); auto operand = tuple->operand(0); @@ -599,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); @@ -653,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { VLOG(3) << "before: " << module->ToString(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); VLOG(3) << "after: " << module->ToString(); @@ -663,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule m add_computation { @@ -684,12 +680,11 @@ TEST_F(HloCseTest, CompareComputations) { r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 ROOT f2 = (f32[],f32[]) tuple(r1, r2) - })") - .ValueOrDie(); + })"); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0), root->operand(1)); } @@ -708,13 +703,13 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { EXPECT_EQ(2, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); } TEST_F(HloCseTest, Domain) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule module ENTRY %entry { %param = f32[] parameter(0), sharding={maximal device=0} @@ -735,13 +730,11 @@ ENTRY %entry { domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} %add = f32[] add(%domain.3, %domain.4) ROOT %sub = f32[] subtract(%add, %domain.5) -})") - .ValueOrDie(); +})"); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - LOG(INFO) << "AAAAA " << module->ToString(); - const HloInstruction* sub = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + const HloInstruction* sub = module().entry_computation()->root_instruction(); const HloInstruction* add = sub->operand(0); EXPECT_EQ(add->operand(0), add->operand(1)); EXPECT_NE(add->operand(0), sub->operand(1)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index bbfb0c253f583b633c4b2c34b2f068b563d3d9e0..6a63681996bc57f4ef16b2405ffc8ce4f003e783 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -29,8 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -46,8 +46,7 @@ namespace { // // In this case, we should be able to reuse p0 and output, although p0 has // multiple uses. -bool MultiDynamicSliceUseShareSameIndices( - tensorflow::gtl::ArraySlice uses) { +bool MultiDynamicSliceUseShareSameIndices(absl::Span uses) { if (uses.empty()) { return false; } @@ -78,8 +77,8 @@ bool MultiDynamicSliceUseShareSameIndices( } // namespace -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; HloDataflowAnalysis::HloDataflowAnalysis( const HloModule& module, bool ssa_form, bool bitcast_defines_value, @@ -93,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis( bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( const HloInstruction* inst) { tensorflow::gtl::FlatSet visited; - tensorflow::gtl::InlinedVector stack; + absl::InlinedVector stack; stack.push_back(inst); while (!stack.empty()) { const HloInstruction* current = stack.back(); @@ -221,7 +220,7 @@ string HloDataflowAnalysis::ToString() const { bool HloDataflowAnalysis::Phi( HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs) { + absl::Span inputs) { CHECK(ssa_form_); VLOG(4) << "Phi(" << instruction->name() << ")"; VLOG(5) << "instruction value set = " @@ -837,7 +836,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { return Unimplemented( "Computation %s is called in both a parallel (eg, kMap) and " "sequential (eg, kCall) context", - computation->name().c_str()); + computation->name()); } if (call_graph_node.caller_callsites().empty() || call_graph_node.context() == CallContext::kParallel) { @@ -886,7 +885,7 @@ StatusOr> HloDataflowAnalysis::Run( VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis( + auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis( module, ssa_form, bitcast_defines_value, fusion_can_share_buffer)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); @@ -976,28 +975,22 @@ Status HloDataflowAnalysis::Verify() const { bool HloDataflowAnalysis::DoesNotUseOperandBuffer( const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; + // Return false if no value at 'operand' and 'index' is used at 'user'. + for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + HloInstruction* fusion_param = + user->fused_parameter(use.operand_number); + const HloValue& value = + GetValueDefinedAt(fusion_param, use.operand_index); + return value.uses().empty(); } + return false; } } } - return true; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index f4abc7a7c7dcfb223067fe946bec0c5ef32f206b..e62c1c2ac81981e1f44f4c7e1479107979576e32 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -138,7 +138,8 @@ class HloDataflowAnalysis { // Returns true if 'user' cannot possibly use the buffer at 'index' in // 'operand'. Returns false otherwise. // - // REQUIRES: 'operand' is an operand of 'user'. + // 'operand' does not have to be an operand of 'user'. This can be the case + // with indirect uses. bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const; @@ -201,7 +202,7 @@ class HloDataflowAnalysis { // the given instruction. If skip_top_level is true, then the top level of the // value set of 'instruction' is not modified. bool Phi(HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs); + absl::Span inputs); // Updates the positions of the HloValues in the output of the given // instruction. This should be called after the instruction value set of diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4755c4a0cf8d268b1c47e596a14605eb2c60b36c..510d6360a1cf94ef06d2ed919a57c7a825886834 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { bool ssa_form = GetParam(); RunAnalysis(ssa_form); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param, xla_while}}); - sequence.insert({condition, {cond_param, cond_constant}}); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, xla_while}); + schedule.set_sequence(condition, {cond_param, cond_constant}); // Construct the order such that 'constant' and its use 'exp' are before // body_param. - sequence.insert({body, {constant, exp, body_param, add}}); + schedule.set_sequence( + body, {constant, exp, body_param, add, dead_constant, dead_negate}); + TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(module_.get(), sequence); + SequentialHloOrdering ordering(schedule); // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. @@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - std::vector order = {param, negate, exp, add}; - sequence.emplace(entry, order); - - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, negate, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); @@ -1963,6 +1966,54 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); } +// Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the +// parameter tuple. +TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto t0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0)); + auto t1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1)); + // Swap the tuple elements. + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0})); + + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); + // The same holds for the parameter tuple, except that the tuple elements are + // swapped in 'tuple'. + EXPECT_TRUE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion)); + EXPECT_FALSE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion)); +} + class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {}; TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { @@ -2286,8 +2337,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index 4e244494d6f98c48f4376bd762f116b9a9c2084d..401204267282b294ca9f701e29e9edd9f0f35b98 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -33,10 +33,10 @@ namespace xla { // // This pass does not remove dead parameter instructions, as parameter // instructions cannot be deleted. -class HloDCE : public HloPassInterface { +class HloDCE : public HloModulePass { public: ~HloDCE() override {} - tensorflow::StringPiece name() const override { return "dce"; } + absl::string_view name() const override { return "dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 26e3736e01270dbc6ca67647e814843aba2d1e3d..3b5cde2996c4195ef458662cd21de85a832d8d55 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index 78955db0da02f16eb93689db947dc1190ab7049a..72185698c9bdcbf2bebed7ee82bc4ed082ce6a14 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext { StatusOr Run(); private: - // Inserts a kDomain instruction between parent and operand, in case - // the attribute (ie, sharding) values change between instruction and operand. - // Returns the newly inserted kDomain instruction, or nullptr if no kDomain - // instruction was necessary. - StatusOr CreateDomain(HloInstruction* instruction, - HloInstruction* parent, - HloInstruction* operand); - HloModule* module_; HloDomainIsolator* isolator_; }; -StatusOr HloDomainIsolator::RunContext::CreateDomain( - HloInstruction* instruction, HloInstruction* parent, - HloInstruction* operand) { - HloInstruction* domain = nullptr; - std::unique_ptr domain_instruction = - isolator_->creator_(instruction, operand); - if (domain_instruction != nullptr) { - domain = operand->parent()->AddInstruction(std::move(domain_instruction)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); - } - return domain; -} - StatusOr HloDomainIsolator::RunContext::Run() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); @@ -71,16 +50,16 @@ StatusOr HloDomainIsolator::RunContext::Run() { // When applying multiple domains, we could end up stacking more than // one in one edge, so here we want to build the effective // (kDomain-less) instruction->operand edge. - HloInstruction* parent = instruction; - while (operand->opcode() == HloOpcode::kDomain) { - parent = operand; - operand = operand->mutable_operand(0); + HloInstruction* root = operand; + while (root->opcode() == HloOpcode::kDomain) { + root = root->mutable_operand(0); } // Check whether a kDomain is necessary between instruction and operand. - TF_ASSIGN_OR_RETURN(HloInstruction * domain, - CreateDomain(instruction, parent, operand)); + HloInstruction* domain = + isolator_->creator_(instruction, root, operand); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); ++added_domains; } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index eded3e78eead76c4564daee119034c5031eba409..c0bf1b9e16b52d81365db277abeb06defeb12d44 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -30,18 +30,20 @@ namespace xla { // used to break an HLO graph edge connecting two instructions with different // sharding. If a set of connected instructions have all the same sharding, no // kDomain instruction will be placed. -class HloDomainIsolator : public HloPassInterface { +class HloDomainIsolator : public HloModulePass { public: // Creates a new kDomain instruction for the edge between the use instruction // (the first HloInstruction argument), and the operand instruction (the - // second HloInstruction argument). + // third HloInstruction argument) if the interesting attribute of the + // instruction differes from the attribute of the root (the second + // HloInstruction argument). // Returns nullptr in case no domain separation is necessary. - using DomainCreator = std::function( - HloInstruction*, HloInstruction*)>; + using DomainCreator = std::function; explicit HloDomainIsolator(DomainCreator creator); - tensorflow::StringPiece name() const override { return "domain_isolator"; } + absl::string_view name() const override { return "domain_isolator"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 9e096320db5048457435199627a1ef1fe1572177..113fd18eae70f0a581e2ab3e44544c47fcab3361 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" @@ -25,14 +26,14 @@ namespace xla { /* static */ StatusOr> HloDomainMap::Create( HloComputation* computation, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); TF_RETURN_IF_ERROR(domain_map->Populate(computation)); return std::move(domain_map); } /* static */ StatusOr> HloDomainMap::Create( HloModule* module, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); for (HloComputation* computation : module->computations()) { TF_RETURN_IF_ERROR(domain_map->Populate(computation)); } @@ -50,20 +51,24 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } +int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const { + return FindOrDie(domain_metadata_id_, instruction); +} + Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); // We only check operands, so we are sure to not process the empty domain from // both sides. for (HloInstruction* operand : instruction->unique_operands()) { if (IsDomainInstruction(operand)) { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); domain->enter_domains.insert(operand); domain->exit_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } if (instruction == instruction->parent()->root_instruction()) { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); domain->enter_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } @@ -71,6 +76,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { } Status HloDomainMap::Populate(HloComputation* computation) { + InstructionOrderMap instructions_post_order; + int64 count = 0; + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { + instructions_post_order.insert(std::make_pair(instruction, count++)); + } for (HloInstruction* instruction : computation->instructions()) { if (IsDomainInstruction(instruction)) { // If this is a kDomain of the kind we are currently processing, check @@ -84,9 +94,46 @@ Status HloDomainMap::Populate(HloComputation* computation) { continue; } TF_ASSIGN_OR_RETURN(std::unique_ptr domain, - CreateDomain(instruction)); + CreateDomain(instruction, instructions_post_order)); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } + TF_RETURN_IF_ERROR(PopulateDomainMetadataMap()); + return Status::OK(); +} + +Status HloDomainMap::PopulateDomainMetadataMap() { + auto hash = [](const DomainMetadata* m) { return m->Hash(); }; + auto equal = [](const DomainMetadata* a, const DomainMetadata* b) { + return a->Matches(*b); + }; + tensorflow::gtl::FlatMap + domain_metadata(1024, hash, equal); + + for (auto& domain : instruction_domains_) { + int64 domain_metadata_id = -1; + if (!domain->enter_domains.empty()) { + const HloInstruction* domain_instruction = *domain->enter_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->user_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else if (!domain->exit_domains.empty()) { + const HloInstruction* domain_instruction = *domain->exit_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->operand_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else { + domain_metadata_id = 0; + } + TF_RET_CHECK(domain_metadata_id >= 0); + for (HloInstruction* instruction : domain->instructions) { + domain_metadata_id_[instruction] = domain_metadata_id; + } + } return Status::OK(); } @@ -142,10 +189,12 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction, } StatusOr> HloDomainMap::CreateDomain( - HloInstruction* instruction) const { - auto domain = MakeUnique(); + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const { + auto domain = absl::make_unique(); TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); - domain->instructions = MakeNonDomainInstructions(domain->reach_set); + domain->instructions = + MakeNonDomainInstructions(domain->reach_set, instructions_order); return std::move(domain); } @@ -167,7 +216,8 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { /* static */ std::vector HloDomainMap::MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set) { + const tensorflow::gtl::FlatSet& instruction_set, + const InstructionOrderMap& instructions_order) { std::vector instructions; instructions.reserve(instruction_set.size()); for (HloInstruction* instruction : instruction_set) { @@ -175,9 +225,10 @@ HloDomainMap::MakeNonDomainInstructions( instructions.push_back(instruction); } } + // sort instructions according to instructions_order std::sort(instructions.begin(), instructions.end(), - [](HloInstruction* a, HloInstruction* b) { - return a->unique_id() < b->unique_id(); + [&instructions_order](HloInstruction* a, HloInstruction* b) { + return instructions_order.at(a) < instructions_order.at(b); }); return instructions; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 1ca71597253eecfb45ae8f384240033a57045277..56b557d7cea424f63cd4891661ae446133ee5a37 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -69,7 +69,17 @@ class HloDomainMap { // instruction is not found within any domain. int64 GetDomainId(HloInstruction* instruction) const; + // Returns the unique id of the domain metadata for the domain the given + // instruction belongs to. The given instruction must not be a kDomain + // instruction since each domain instruction is associated with 2 domains. + int64 GetDomainMetadataId(HloInstruction* instruction) const; + private: + // Map used for representing instruction ordering, i.e. + // order_map[a] < order_map[b] means a must be ordered before b. + using InstructionOrderMap = + tensorflow::gtl::FlatMap; + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} // Check if the kDomain instruction is facing (via its operand link) another @@ -95,16 +105,23 @@ class HloDomainMap { // Creates a domain data structure using the ExpandDomain() API. StatusOr> CreateDomain( - HloInstruction* instruction) const; + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const; // Out of an instruction set, returns a vector of all the ones which are not // a kDomain kind. static std::vector MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set); + const tensorflow::gtl::FlatSet& instruction_set, + const InstructionOrderMap& instructions_order); + + // Populates domain_metadata_id_ that maps each HloInstruction to the unique + // ID of its associated domain metatadata. + Status PopulateDomainMetadataMap(); string domain_kind_; std::vector> instruction_domains_; tensorflow::gtl::FlatMap instruction_to_domain_; + tensorflow::gtl::FlatMap domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index f855f2a1fc944fcc11c9afed278bef4af87813da..302807f816e4ab626af419023e7740fd6bde795f 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,10 +20,10 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -44,7 +44,10 @@ class DomainMetadata { // two domains of different kind intersect each other. tensorflow::gtl::FlatSet reach_set; - // The same instructions in reach_set, but purged from kDomain instructions. + // The same instructions in reach_set, but purged from kDomain instructions + // and ordered according to their computation graph post-order, i.e. + // if instructions[pos_a] depends on instructions[pos_b], then pos_a > + // pos_b. std::vector instructions; // If we consider a graph edge as an arrow oriented from the operand to the @@ -63,12 +66,15 @@ class DomainMetadata { // Returns the metadata type. A unique identifier which describes the real // metadata type. - virtual tensorflow::StringPiece Kind() const = 0; + virtual absl::string_view Kind() const = 0; // Compares the metadata object with another one and returns true if the // two matches. virtual bool Matches(const DomainMetadata& other) const = 0; + // Returns the hash value of the metadata. + virtual size_t Hash() const = 0; + // Returns a string representation of the metadata. virtual string ToString() const = 0; }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index c859e05f02e54d601804b641094ecdd11bbe1aed..0fc30fb86c337a8bba5957d504caa7deeac9b86c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -26,7 +26,7 @@ namespace xla { // Removes all the kDomain instructions of a given kind from the input module, // and calls the normalizer to propagate the properties on the possibly new born // instructions. -class HloDomainRemover : public HloPassInterface { +class HloDomainRemover : public HloModulePass { public: // Creates a new HloDomainRemover object tasked at removing all the kDomain // instructions of a given kind. @@ -35,13 +35,13 @@ class HloDomainRemover : public HloPassInterface { // instructions in it with the same attributes (ie, sharding), a normalizer // function is tasked at applying attribute normalization on the instructions // within such domain. - HloDomainRemover(tensorflow::StringPiece kind, + HloDomainRemover(absl::string_view kind, std::function normalizer) - : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + : kind_(kind), normalizer_(std::move(normalizer)) {} - tensorflow::StringPiece name() const override { return "domain_remover"; } + absl::string_view name() const override { return "domain_remover"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index ffc18a0f886df86d87944d9c284a6faf8afe4c60..43e74d2f6f07bd685ad8683401138a4f06cd2ad2 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -45,9 +46,8 @@ class HloDomainTest : public HloVerifiedTestBase { // Checks whether there is a kDomain instruction in the edge between the // instruction and the operand. - bool HasDomainEdge(HloModule* module, - tensorflow::StringPiece instruction_name, - tensorflow::StringPiece operand_name) { + bool HasDomainEdge(HloModule* module, absl::string_view instruction_name, + absl::string_view operand_name) { HloInstruction* instruction = FindInstruction(module, instruction_name); HloInstruction* operand = FindInstruction(module, operand_name); CHECK_NE(instruction, nullptr); @@ -65,7 +65,7 @@ class HloDomainTest : public HloVerifiedTestBase { return false; } - StatusOr ParseModule(tensorflow::StringPiece hlo_string) { + StatusOr ParseModule(absl::string_view hlo_string) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); ParseAndVerifyModule(hlo_string, config); @@ -80,10 +80,10 @@ class OpNameMetadata : public DomainMetadata { explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} std::unique_ptr Clone() const override { - return MakeUnique(opname_); + return absl::make_unique(opname_); } - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return KindName(); } bool Matches(const DomainMetadata& other) const override { const OpNameMetadata* other_ptr = @@ -97,25 +97,28 @@ class OpNameMetadata : public DomainMetadata { string ToString() const override { return opname_; } - static tensorflow::StringPiece KindName() { return "opname"; } + static absl::string_view KindName() { return "opname"; } + + size_t Hash() const override { return std::hash()(opname_); } private: string opname_; }; // Creator function for OpNameMetadata domains. -std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, - HloInstruction* operand) { - if (instruction->metadata().op_name() == operand->metadata().op_name()) { +HloInstruction* OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { + if (instruction->metadata().op_name() == root->metadata().op_name()) { return nullptr; } std::unique_ptr operand_side_metadata = - MakeUnique(operand->metadata().op_name()); + absl::make_unique(root->metadata().op_name()); std::unique_ptr user_side_metadata = - MakeUnique(instruction->metadata().op_name()); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); + absl::make_unique(instruction->metadata().op_name()); + return operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, std::move(operand_side_metadata), + std::move(user_side_metadata))); } Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain, @@ -142,7 +145,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -184,7 +187,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(!isolator_changed); } @@ -211,7 +214,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -248,7 +251,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_FALSE(isolator_changed); } @@ -302,7 +305,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator sharding_isolator(CreateShardingDomain); + HloDomainIsolator sharding_isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, sharding_isolator.Run(module)); EXPECT_TRUE(sharding_isolator_changed); @@ -344,7 +347,8 @@ ENTRY entry { token = token[] after-all() infeed = ((f32[4], f32[4]), token[]) infeed(token), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} - infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0 + infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, + sharding={{maximal device=1}, {maximal device=0}} gte0 = f32[4] get-tuple-element(infeed.data), index=0 gte1 = f32[4] get-tuple-element(infeed.data), index=1 copy0 = f32[4] copy(gte0) @@ -356,7 +360,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -378,11 +382,8 @@ ENTRY entry { // \ / // TUPLE // | - HloInstruction* infeed = FindInstruction(module, "infeed"); - ASSERT_NE(infeed, nullptr); - HloInstruction* infeed_data = - infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* infeed_data = FindInstruction(module, "infeed.data"); + ASSERT_NE(infeed_data, nullptr); auto infeed_data_users = infeed_data->users(); HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction( @@ -445,7 +446,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -474,8 +475,8 @@ ENTRY entry { TEST_F(HloDomainTest, DumpParseNullSharding) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {}); - auto sharding_md_0 = MakeUnique(nullptr); - auto sharding_md_1 = MakeUnique(nullptr); + auto sharding_md_0 = absl::make_unique(nullptr); + auto sharding_md_1 = absl::make_unique(nullptr); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain( @@ -490,5 +491,203 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { ASSERT_TRUE(ParseModule(hlo_string).status().ok()); } +// Tuple inputs are domain instructions. +TEST_F(HloDomainTest, DomainTuple) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = f32[4] parameter(0), sharding={maximal device=0} + cst = u32[] constant(0), sharding={maximal device=1} + tpl = (u32[], f32[4]) tuple(cst, p0), + sharding={{maximal device=1}, {maximal device=0}} + ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(ShardingDomainCreator{}); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + // Clear sharding of tpl instruction, in order to test domain sharding + // application. + auto tpl = FindInstruction(module, "tpl"); + tpl->clear_sharding(); + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + tpl->sharding()); +} + +TEST_F(HloDomainTest, MultiDomainMultiUser) { + const char* const hlo_string = R"( + HloModule Module + +ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { + %p0 = (f32[4], f32[4]) parameter(0) + %a = f32[4]{0} get-tuple-element(%p0), index=0 + %domain = f32[4] domain(%a), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %b = f32[4] get-tuple-element(%p0), index=1 + %domain.1 = f32[4] domain(%b), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1} + %domain.2 = f32[4] domain(%c), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %d = f32[4] subtract(%domain, %c), + sharding={maximal device=1}, metadata={op_name="D"} + %domain.3 = f32[4] domain(%d), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %e = f32[4] multiply(%c, %d), + sharding={maximal device=1}, metadata={op_name="D"} + %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1} + %domain.4 = f32[4]{0} domain(%f), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4) +})"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module)); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module)); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module)); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "c")); +} + +// Emulate instructions inserted at top and bottom within nested tuple domain. +TEST_F(HloDomainTest, DomainTupleTopBottomInsert) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = f32[4] parameter(0), sharding={maximal device=1} + p1 = (f32[5], f32[6]) parameter(1), + sharding={{maximal device=1}, {maximal device=0}} + tuple.0 = (f32[4], (f32[5], f32[6])) tuple(p0, p1), + sharding={{maximal device=1}, {maximal device=1}, {maximal device=0}} + ROOT res = (f32[5], f32[6]) get-tuple-element(tuple.0), index=1, + sharding={{maximal device=1}, {maximal device=0}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(ShardingDomainCreator{}); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + // Clear sharding of tuple.0 instruction, in order to test domain sharding + // application. + auto tuple0 = FindInstruction(module, "tuple.0"); + tuple0->clear_sharding(); + + // Insert the following instructons above and below tuple.0, to emulate other + // passes effects: + // COPY.0 + // \ / + // TUPLE.0 + // / \ + // COPY.1 \ + // / \ + // GTE.0 GTE.1 + // | | + // | COPY.2 + // \ / + // \ / + // TUPLE.1 + // | + auto tuple0_users = tuple0->users(); + auto computation = tuple0->parent(); + HloInstruction* copy0 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->operand(1)->shape(), HloOpcode::kCopy, + tuple0->mutable_operand(1))); + TF_EXPECT_OK(tuple0->ReplaceOperandWith(1, copy0)); + + HloInstruction* copy1 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->shape(), HloOpcode::kCopy, tuple0)); + HloInstruction* gte0 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(copy1->shape(), 0), copy1, 0)); + HloInstruction* gte1 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(tuple0->shape(), 1), tuple0, 1)); + HloInstruction* copy2 = computation->AddInstruction( + HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1)); + HloInstruction* tuple1 = + computation->AddInstruction(HloInstruction::CreateTuple({gte0, copy2})); + + for (HloInstruction* user : tuple0_users) { + TF_EXPECT_OK(tuple0->ReplaceUseWith(user, tuple1)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_TRUE(tuple0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + tuple0->sharding()); + + EXPECT_TRUE(copy0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + copy0->sharding()); + + // copy1 has partial information only from gte.0, so in the end it gets no + // sharding at all. During propagation it does propagate the information from + // gte.0 though, enabling Tuple.0 to be fully sharded. + EXPECT_FALSE(copy1->has_sharding()); + + EXPECT_TRUE(gte0->has_sharding()); + EXPECT_EQ(HloSharding::AssignDevice(1), gte0->sharding()); + + EXPECT_TRUE(gte1->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(gte1->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + gte1->sharding()); + + EXPECT_TRUE(copy2->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy2->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + copy2->sharding()); + + EXPECT_TRUE(tuple1->has_sharding()); + EXPECT_EQ(tuple0->sharding(), tuple1->sharding()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc index 751fc677e2d955fd3d9f8970f7c0370a22c054bf..dc514ae3e5c6907f6398805d171e69ee8635d08e 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc @@ -52,7 +52,7 @@ Status HloDomainVerifier::RunContext::PopulateDomainKinds() { TF_RET_CHECK(instruction->user_side_metadata().Kind() == instruction->operand_side_metadata().Kind()) << instruction->ToString(); - kinds.insert(instruction->user_side_metadata().Kind().ToString()); + kinds.insert(string(instruction->user_side_metadata().Kind())); } } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h index 8e53cf97f8ba9a88140a909ad20c1a938aec8c1f..bea5cba38d018029c9805e1593fadad54460447e 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -29,11 +29,11 @@ namespace xla { // Verifies that the domain instructions are consistent, and the each domain is // surrounded by the same metadata. -class HloDomainVerifier : public HloPassInterface { +class HloDomainVerifier : public HloModulePass { public: HloDomainVerifier(std::vector kinds) : kinds_(std::move(kinds)) {} - tensorflow::StringPiece name() const override { return "domain_verifier"; } + absl::string_view name() const override { return "domain_verifier"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index c804f4364f6d16d5b8112219ce884495200aa827..72006e17e7e7ec09b62e88d05b695ec9f4c49647 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -144,13 +144,18 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || + opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || opcode == HloOpcode::kConditional) { continue; } TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); - if (!HasOperandType(hlo, eliminate_type_)) { + bool nullary = hlo->operands().empty(); + bool wrong_element_type = hlo->shape().element_type() == eliminate_type_; + bool should_eliminate_type = (nullary && wrong_element_type) || + HasOperandType(hlo, eliminate_type_); + if (!should_eliminate_type) { // If this CHECK fires, then this was an instruction that does not take // the elimination type as an operand but it does return it. This pass // does not have a feature to change the output type in that case, so diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h index 2b109225d0b192e5c9e4f6d841377ffad8078dc2..4d2a942925288ba4c3977ffcd25b55746a555a5e 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -25,16 +25,14 @@ namespace xla { // inserting Convert ops. This allows a backend to support an element type while // only actually implementing the Convert op for that element type. This is // generally not the fastest approach, but it works. -class HloElementTypeConverter : public HloPassInterface { +class HloElementTypeConverter : public HloModulePass { public: // eliminate_type is the type to eliminate as the input or output of ops, // using Convert ops to replace it with replace_with_type. HloElementTypeConverter(PrimitiveType eliminate_type, PrimitiveType replace_with_type); - tensorflow::StringPiece name() const override { - return "element_type_converter"; - } + absl::string_view name() const override { return "element_type_converter"; } // Returns the pass on the module and returns whether the module was modified. StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 51353eea6e72d5a131897f3c3ae312046051103e..d7c39b2778d57c1b2e9da0d87d9c2b91bb47e968 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -23,13 +23,15 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -43,7 +45,6 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -52,12 +53,9 @@ namespace xla { namespace { -using tensorflow::gtl::ArraySlice; - template -StatusOr> Compare(const Shape& shape, HloOpcode opcode, - LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -95,19 +93,20 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = MakeUnique(shape); - TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } template <> -StatusOr> Compare( - const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -125,11 +124,12 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = MakeUnique(shape); - TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -138,49 +138,62 @@ StatusOr> Compare( HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { - typed_visitors_[PRED] = MakeUnique>(this); - typed_visitors_[U8] = MakeUnique>(this); - typed_visitors_[U16] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "U16."); - }); - typed_visitors_[U32] = MakeUnique>(this); - typed_visitors_[U64] = MakeUnique>(this); - typed_visitors_[S8] = MakeUnique>(this); - typed_visitors_[S16] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "S16."); - }); - typed_visitors_[S32] = MakeUnique>(this); - typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[PRED] = + absl::make_unique>(this); + typed_visitors_[U8] = + absl::make_unique>(this); + typed_visitors_[U16] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); + }); + typed_visitors_[U32] = + absl::make_unique>(this); + typed_visitors_[U64] = + absl::make_unique>(this); + typed_visitors_[S8] = absl::make_unique>(this); + typed_visitors_[S16] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); + }); + typed_visitors_[S32] = + absl::make_unique>(this); + typed_visitors_[S64] = + absl::make_unique>(this); typed_visitors_[F16] = - MakeUnique>(this); - typed_visitors_[F32] = MakeUnique>(this); - typed_visitors_[F64] = MakeUnique>(this); - typed_visitors_[C64] = MakeUnique>(this); + absl::make_unique>(this); + typed_visitors_[F32] = + absl::make_unique>(this); + typed_visitors_[F64] = + absl::make_unique>(this); + typed_visitors_[C64] = + absl::make_unique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. typed_visitors_[BF16] = - MakeUnique>(this); - - typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); - }); - typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); - }); + absl::make_unique>(this); + + typed_visitors_[TUPLE] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); + }); + typed_visitors_[OPAQUE] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); + }); } template -StatusOr> HloEvaluator::Evaluate( - const HloModule& module, ArraySlice arg_literals) { +StatusOr HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); evaluated_.clear(); @@ -192,12 +205,23 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .CloneToUnique(); + .Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(module, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( - const HloComputation& computation, ArraySlice arg_literals) { +StatusOr HloEvaluator::Evaluate( + const HloComputation& computation, + absl::Span arg_literals) { CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); @@ -209,14 +233,23 @@ StatusOr> HloEvaluator::Evaluate( } TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique(); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloComputation& computation, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(computation, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction, ArraySlice arg_literals) { +StatusOr HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); evaluated_.clear(); arg_literals_.clear(); @@ -233,18 +266,27 @@ StatusOr> HloEvaluator::Evaluate( << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - evaluated_[operand] = input_literal->CloneToUnique(); + evaluated_[operand] = input_literal->Clone(); } } TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal : arg_literals) { + arg_literal_ptrs.push_back(&literal); + } + return Evaluate(instruction, arg_literal_ptrs); } -StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction) { +StatusOr HloEvaluator::Evaluate(HloInstruction* instruction) { if (instruction->opcode() == HloOpcode::kParameter) { return tensorflow::errors::FailedPrecondition( "Cannot evaluate a parameter."); @@ -253,7 +295,6 @@ StatusOr> HloEvaluator::Evaluate( return tensorflow::errors::FailedPrecondition( "Not all operands are constants."); } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_.clear(); evaluated_.clear(); @@ -261,21 +302,22 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); } -std::unique_ptr HloEvaluator::TryEvaluate( - HloInstruction* instruction) { +bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) { + CHECK(result != nullptr); auto result_or = Evaluate(instruction); if (!result_or.ok()) { VLOG(1) << "TryEvaluate failed:" << result_or.status(); - return nullptr; + return false; } - return result_or.ConsumeValueOrDie(); + *result = result_or.ConsumeValueOrDie(); + return true; } -StatusOr> HloEvaluator::EvaluateWithSubstitutions( +StatusOr HloEvaluator::EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions) { @@ -286,7 +328,7 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( owned_operands.push_back(operand->Clone()); } else { owned_operands.push_back( - HloInstruction::CreateConstant(it->second->CloneToUnique())); + HloInstruction::CreateConstant(it->second->Clone())); } } @@ -303,12 +345,12 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( +StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), @@ -318,10 +360,10 @@ StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( +StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand) { std::unique_ptr operand_instr = - HloInstruction::CreateConstant(operand.CloneToUnique()); + HloInstruction::CreateConstant(operand.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); @@ -330,13 +372,14 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( return result; } -StatusOr> HloEvaluator::EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, +StatusOr HloEvaluator::EvaluateDotOp( + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); TF_ASSIGN_OR_RETURN( Shape dot_shape, @@ -344,7 +387,7 @@ StatusOr> HloEvaluator::EvaluateDotOp( std::unique_ptr cloned_instruction = HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(), - dim_numbers); + dim_numbers, precision_config); return Evaluate(cloned_instruction.get()); } @@ -357,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { << ", but input literal shape is: " << ShapeUtil::HumanString(input_literal->shape()); - evaluated_[parameter] = input_literal->CloneToUnique(); + evaluated_[parameter] = input_literal->Clone(); return Status::OK(); } @@ -378,7 +421,7 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { } Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { - ArraySlice operands(concatenate->operands()); + absl::Span operands(concatenate->operands()); // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); @@ -407,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( GetEvaluatedLiteralFor(operand), source_indices, dest_indices, AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += @@ -423,7 +466,7 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { if (!ShapeUtil::ElementIsFloating(operand->shape())) { return InvalidArgument( "expected element type in shape to be float for IsFinite op, got: %s", - PrimitiveType_Name(operand->shape().element_type()).c_str()); + PrimitiveType_Name(operand->shape().element_type())); } switch (operand->shape().element_type()) { @@ -453,6 +496,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { return Status::OK(); } +Status HloEvaluator::HandleReal(HloInstruction* real) { + auto operand = real->operand(0); + switch (operand->shape().element_type()) { + case BF16: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](bfloat16 elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case C64: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](complex64 elem_operand) { return std::real(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F16: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](Eigen::half elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F32: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](float elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F64: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](double elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: " + << PrimitiveType_Name(operand->shape().element_type()); + } + + return Status::OK(); +} + +Status HloEvaluator::HandleImag(HloInstruction* imag) { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + return Status::OK(); +} + Status HloEvaluator::HandleCompare(HloInstruction* compare) { HloOpcode opcode = compare->opcode(); auto lhs = compare->operand(0); @@ -464,9 +562,9 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s", - ShapeUtil::HumanString(compare->shape()).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(compare->shape()), + ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); @@ -555,43 +653,41 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -// Returns an ShapeUtil::IndexIterationSpace that iterates over the output -// gather dimensions while keeping the rest of the output dimensions clamped to -// 0. -ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices( +// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch +// dimensions while keeping the rest of the output dimensions clamped to 0. +ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) { int64 output_rank = output_shape.dimensions_size(); std::vector index_base(output_rank, 0); std::vector index_count; index_count.reserve(output_rank); for (int64 i = 0; i < output_rank; i++) { - bool is_output_gather_dim = - !c_binary_search(dim_numbers.output_window_dims(), i); - index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i) - : 1); + bool is_output_batch_dim = + !absl::c_binary_search(dim_numbers.offset_dims(), i); + index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1); } return {std::move(index_base), std::move(index_count), std::vector(output_rank, 1)}; } -// Return an ShapeUtil::IndexIterationSpace that iterates over the output window +// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice // dimensions while keeping the rest of the output dimensions clamped to 0. -ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( - int64 output_rank, ArraySlice window_bounds, +ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices( + int64 output_rank, absl::Span slice_sizes, const GatherDimensionNumbers& dim_numbers) { std::vector index_base(output_rank, 0); std::vector index_count(output_rank, 1); - int64 window_bounds_idx = 0; + int64 slice_sizes_idx = 0; for (int64 i = 0; i < output_rank; i++) { bool is_output_window_dim = - c_binary_search(dim_numbers.output_window_dims(), i); + absl::c_binary_search(dim_numbers.offset_dims(), i); if (is_output_window_dim) { - while (c_binary_search(dim_numbers.elided_window_dims(), - window_bounds_idx)) { - window_bounds_idx++; + while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), + slice_sizes_idx)) { + slice_sizes_idx++; } - index_count[i] = window_bounds[window_bounds_idx++]; + index_count[i] = slice_sizes[slice_sizes_idx++]; } } @@ -599,30 +695,30 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( std::vector(output_rank, 1)}; } -// This functor computes the contribution of gather_indices to an input index +// This functor computes the contribution of start_indices to an input index // corresponding to an output index. That is, given an output index I, it picks -// out the gather output indices in I and uses them to look up a gather index, -// G, from the gather indices tensor, and expands G into the input space -// according to gather_dims_to_operand_dims. -class OutputGatherIndexToInputIndex { +// out the batch indices in I and uses them to look up a starting index, G, from +// the start indices tensor, and expands G into the input space according to +// start_index_map. +class OutputBatchIndexToInputIndex { public: // The constructor does some setup work that is amortized across all // iterations. - explicit OutputGatherIndexToInputIndex( + explicit OutputBatchIndexToInputIndex( const GatherDimensionNumbers* dim_numbers, const Shape& input_shape, - const Shape& output_shape, const Literal* gather_indices) - : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) { + const Shape& output_shape, const Literal* start_indices) + : dim_numbers_(*dim_numbers), start_indices_(*start_indices) { for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - output_dim_is_gather_dims_.push_back( - !c_binary_search(dim_numbers_.output_window_dims(), i)); + output_dim_is_batch_dims_.push_back( + !absl::c_binary_search(dim_numbers_.offset_dims(), i)); } for (int64 i = 0; i < input_shape.dimensions_size(); i++) { int64 index_of_input_dim_in_index_vector = - std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(), - c_find(dim_numbers_.gather_dims_to_operand_dims(), i)); + std::distance(dim_numbers_.start_index_map().begin(), + absl::c_find(dim_numbers_.start_index_map(), i)); if (index_of_input_dim_in_index_vector == - dim_numbers_.gather_dims_to_operand_dims_size()) { + dim_numbers_.start_index_map_size()) { input_dim_value_to_index_vector_.push_back(-1); } else { input_dim_value_to_index_vector_.push_back( @@ -630,14 +726,14 @@ class OutputGatherIndexToInputIndex { } } - index_vector_index_.resize(gather_indices_.shape().dimensions_size()); + index_vector_index_.resize(start_indices_.shape().dimensions_size()); input_index_.resize(input_shape.dimensions_size()); int64 index_vector_size = - gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + start_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); index_vector_.resize(index_vector_size); } - // Returns the contribution of gather_indices to the input index corresponding + // Returns the contribution of start_indices to the input index corresponding // to output_index. See gather_inner_loop_body. // // This is conceptually a stateless transformation from output_index to the @@ -650,24 +746,25 @@ class OutputGatherIndexToInputIndex { // index_vector_index_ and index_vector on every invocation, we reuse the // same storage for all invocations. // - // This returns an arrayslice into memory owned by the class. - StatusOr> operator()(ArraySlice output_index) { + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span output_index) { PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); TF_RETURN_IF_ERROR(FetchIndexVector()); PropagateIndexVectorToInputIndex(); - return ArraySlice(input_index_); + return absl::Span(input_index_); } private: - // Propagates the gather index dimensions from the output index into + // Propagates the batch dimensions from the output index into // index_vector_index_ by mutating index_vector_index_ in place. Does not // update the dim_numbers.index_vector_dim() dimension -- that's the dimension // we iterate over in FetchIndexVector. void PropagateOutputIndexGatherDimsToIndexVectorIndex( - ArraySlice output_index) { + absl::Span output_index) { int64 index_vector_index_i = 0; for (int64 i = 0, e = output_index.size(); i < e; i++) { - if (!output_dim_is_gather_dims_[i]) { + if (!output_dim_is_batch_dims_[i]) { continue; } @@ -679,14 +776,14 @@ class OutputGatherIndexToInputIndex { } } - // Populates index_vector_ by iterating over gather_indices_ according to + // Populates index_vector_ by iterating over start_indices_ according to // index_vector_index_. Status FetchIndexVector() { int64 index_vector_dim = dim_numbers_.index_vector_dim(); for (int64 i = 0, e = index_vector_.size(); i < e; i++) { index_vector_index_[index_vector_dim] = i; - TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64( - index_vector_index_)); + TF_ASSIGN_OR_RETURN(index_vector_[i], + start_indices_.GetIntegralAsS64(index_vector_index_)); } return Status::OK(); } @@ -708,40 +805,39 @@ class OutputGatherIndexToInputIndex { // PropagateIndexVectorToInputIndex. std::vector input_dim_value_to_index_vector_; - // output_dim_is_gather_dims_[i] is true iff the output index i is a gather + // output_dim_is_batch_dims_[i] is true iff the output index i is a gather // dimension. - std::vector output_dim_is_gather_dims_; + std::vector output_dim_is_batch_dims_; - // The buffer into which we construct an index into gather_indices_ to fetch + // The buffer into which we construct an index into start_indices_ to fetch // the index vector. std::vector index_vector_index_; - // The index vector fetched from gather_indices_. + // The index vector fetched from start_indices_. std::vector index_vector_; - // The result computed by this functor. operator() returns an ArraySlice into + // The result computed by this functor. operator() returns a Span into // this vector. std::vector input_index_; const GatherDimensionNumbers& dim_numbers_; - const Literal& gather_indices_; + const Literal& start_indices_; }; -// This functor computes the contribution of the window indices in an output +// This functor computes the contribution of the offset indices in an output // index to an input index. That is, given an output index I it picks out the -// output window indices in I and expands it into a window index into the input -// shape. -class OutputWindowIndexToInputIndex { +// output offset indices in I and expands it into an index into the input shape. +class OutputOffsetIndexToInputIndex { public: // The constructor does some setup work that is amortized across all // iterations. - explicit OutputWindowIndexToInputIndex( + explicit OutputOffsetIndexToInputIndex( const GatherDimensionNumbers& dim_numbers, const Shape& input_shape, const Shape& output_shape) { std::vector window_index_to_output_index; int64 output_index_count = 0; for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.output_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.offset_dims(), i)) { window_index_to_output_index.push_back(output_index_count++); } else { output_index_count++; @@ -750,7 +846,7 @@ class OutputWindowIndexToInputIndex { int64 window_dim_count = 0; for (int64 i = 0; i < input_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { input_dim_value_to_output_index_.push_back(-1); } else { input_dim_value_to_output_index_.push_back( @@ -769,10 +865,11 @@ class OutputWindowIndexToInputIndex { // gather input index on every invocation we reuse the same storage for the // result (input_index_), mutating it in place. // - // This returns an arrayslice into memory owned by the class. - StatusOr> operator()(ArraySlice output_index) { + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span output_index) { PropagateOutputIndexWindowDimsToInputIndex(output_index); - return ArraySlice(input_index_); + return absl::Span(input_index_); } // Returns for a given 'input_dim' the corresponding output dimension index, @@ -785,7 +882,7 @@ class OutputWindowIndexToInputIndex { // Propagates window dimensions from the output index to input_index_ by // mutating input_index_ in place. void PropagateOutputIndexWindowDimsToInputIndex( - ArraySlice output_index) { + absl::Span output_index) { for (int64 i = 0, e = input_index_.size(); i < e; i++) { if (input_dim_value_to_output_index_[i] != -1) { input_index_[i] = output_index[input_dim_value_to_output_index_[i]]; @@ -801,119 +898,117 @@ class OutputWindowIndexToInputIndex { // PropagateOutputIndexWindowDimsToInputIndex. std::vector input_dim_value_to_output_index_; - // The result computed by this functor. operator() returns an ArraySlice into + // The result computed by this functor. operator() returns a Span into // this vector. std::vector input_index_; }; // Rehapes the gather indices input to have a trailing degenerate `1` dimension // if necessary. Hands over the ownership of the newly created literal (if -// there is one) to `reshaped_gather_indices`. +// there is one) to `reshaped_start_indices`. static StatusOr> ReshapedGatherIndices( - int64 index_vector_dim, const Literal& gather_indices, - std::unique_ptr* reshaped_gather_indices) { - if (gather_indices.shape().dimensions_size() != index_vector_dim) { - return std::cref(gather_indices); + int64 index_vector_dim, const Literal& start_indices, + Literal* reshaped_start_indices) { + if (start_indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(start_indices); } - std::vector new_shape(gather_indices.shape().dimensions().begin(), - gather_indices.shape().dimensions().end()); + std::vector new_shape(start_indices.shape().dimensions().begin(), + start_indices.shape().dimensions().end()); new_shape.push_back(1); - TF_ASSIGN_OR_RETURN(*reshaped_gather_indices, - gather_indices.Reshape(new_shape)); - return std::cref(**reshaped_gather_indices); + TF_ASSIGN_OR_RETURN(*reshaped_start_indices, + start_indices.Reshape(new_shape)); + return std::cref(*reshaped_start_indices); } Status HloEvaluator::HandleGather(HloInstruction* gather) { - std::unique_ptr result = Literal::CreateFromShape(gather->shape()); + Literal result = Literal::CreateFromShape(gather->shape()); const Shape& shape = gather->shape(); const GatherDimensionNumbers& dim_numbers = gather->gather_dimension_numbers(); const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); - std::unique_ptr reshaped_gather_indices; + Literal reshaped_start_indices; TF_ASSIGN_OR_RETURN( - const Literal& gather_indices, + const Literal& start_indices, ReshapedGatherIndices(dim_numbers.index_vector_dim(), GetEvaluatedLiteralFor(gather->operand(1)), - &reshaped_gather_indices)); + &reshaped_start_indices)); // We iterate over the gather dimensions in the output shape in an outer loop // nest, and iterate over the window dimensions in the output shape in an // inner loop nest. - ShapeUtil::IndexIterationSpace gather_indices_iteration_space = - IterationSpaceForOutputGatherIndices(shape, dim_numbers); - ShapeUtil::IndexIterationSpace window_indices_iteration_space = - IterationSpaceForOutputWindowIndices( - shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers); + ShapeUtil::IndexIterationSpace start_indices_iteration_space = + IterationSpaceForOutputBatchIndices(shape, dim_numbers); + ShapeUtil::IndexIterationSpace offset_indices_iteration_space = + IterationSpaceForOutputOffsetIndices( + shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers); // Scratch buffers that hold an index in the output shape and the // corresponding index in the input shape. std::vector input_index(operand.shape().dimensions_size()); std::vector output_index(gather->shape().dimensions_size()); - std::vector input_gather_index_clamped( - operand.shape().dimensions_size()); + std::vector input_index_clamped(operand.shape().dimensions_size()); - OutputGatherIndexToInputIndex output_gather_index_to_input_index( + OutputBatchIndexToInputIndex output_batch_index_to_input_index( &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), - /*output_shape=*/shape, &gather_indices); - OutputWindowIndexToInputIndex output_window_index_to_input_index( + /*output_shape=*/shape, &start_indices); + OutputOffsetIndexToInputIndex output_offset_index_to_input_index( gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), /*output_shape=*/shape); const Shape& operand_shape = operand.shape(); auto gather_inner_loop_body = - [&](ArraySlice output_window_index, - ArraySlice input_gather_index, - ArraySlice output_gather_index) -> StatusOr { + [&](absl::Span output_window_index, + absl::Span input_gather_index, + absl::Span output_gather_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - ArraySlice input_window_index, - output_window_index_to_input_index(output_window_index)); + absl::Span input_window_index, + output_offset_index_to_input_index(output_window_index)); for (int i = 0, e = output_index.size(); i < e; i++) { output_index[i] = output_gather_index[i] + output_window_index[i]; DCHECK_LT(output_index[i], shape.dimensions(i)); } for (int i = 0, e = input_gather_index.size(); i < e; i++) { int64 output_dim = - output_window_index_to_input_index.input_dim_value_to_output_index(i); + output_offset_index_to_input_index.input_dim_value_to_output_index(i); // If 'output_dim' is -1, it means 'i' is an elided window dim. This means // we set the iteration index to 0, so for the purpose of the following // calculations we can consider the output dimension size to be 1. int64 output_dim_size = output_dim == -1 ? 1 : shape.dimensions(output_dim); // Clamp the gather index so that the gather region fits in the operand. - // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0, + // input_index_clamped[i] = clamp(input_gather_index[i], 0, // operand_shape.dimensions(i) - // output_dim_size); - input_gather_index_clamped[i] = + input_index_clamped[i] = std::min(operand_shape.dimensions(i) - output_dim_size, std::max(0LL, input_gather_index[i])); } for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_gather_index_clamped[i] + input_window_index[i]; + input_index[i] = input_index_clamped[i] + input_window_index[i]; DCHECK_GE(input_index[i], 0); DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } TF_RETURN_IF_ERROR( - result->CopyElementFrom(operand, input_index, output_index)); + result.CopyElementFrom(operand, input_index, output_index)); return true; }; auto gather_outer_loop_body = - [&](ArraySlice output_gather_index) -> StatusOr { - TF_ASSIGN_OR_RETURN( - ArraySlice input_gather_index, - output_gather_index_to_input_index(output_gather_index)); + [&](absl::Span output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN(absl::Span input_gather_index, + output_batch_index_to_input_index(output_gather_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - shape, window_indices_iteration_space, + shape, offset_indices_iteration_space, std::bind(gather_inner_loop_body, std::placeholders::_1, input_gather_index, output_gather_index))); return true; }; TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - shape, gather_indices_iteration_space, gather_outer_loop_body)); + shape, start_indices_iteration_space, gather_outer_loop_body)); evaluated_[gather] = std::move(result); return Status::OK(); } @@ -929,8 +1024,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand.shape().dimensions(i)); + auto operand_dim_size = operand.shape().dimensions(i); + auto broadcast_dim_size = + broadcast->shape().dimensions(broadcast->dimensions(i)); + TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat( + "Operand dimension %d is broadcast to output dimension %d, but the " + "sizes of these two dims do not match (%d vs %d): %s", + i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size, + broadcast->ToString()); } TF_ASSIGN_OR_RETURN( @@ -960,18 +1061,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = MakeUnique( - ShapeUtil::GetTupleElementShape(operand->shape(), index)); - return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, - /*dest_shape_index=*/{}, - /*src_shape_index=*/{index}); + evaluated_[get_tuple_element] = + Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index)); + return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal, + /*dest_shape_index=*/{}, + /*src_shape_index=*/{index}); } Status HloEvaluator::HandleCopy(HloInstruction* copy) { TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); - - auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique(); - evaluated_[copy] = std::move(result); + evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone(); return Status::OK(); } @@ -987,7 +1086,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - std::unique_ptr result = + Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); @@ -1019,7 +1118,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; - std::unique_ptr result = + Literal result = embedded_evaluator .Evaluate(*readded_computation, arg_literals) .ConsumeValueOrDie(); @@ -1039,7 +1138,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; - std::unique_ptr result; + Literal result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, @@ -1064,9 +1163,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { - evaluated_[select] = on_true.CloneToUnique(); + evaluated_[select] = on_true.Clone(); } else { - evaluated_[select] = on_false.CloneToUnique(); + evaluated_[select] = on_false.Clone(); } return Status::OK(); } @@ -1080,9 +1179,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) { const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2)); if (pred.Get({})) { - evaluated_[tuple_select] = on_true.CloneToUnique(); + evaluated_[tuple_select] = on_true.Clone(); } else { - evaluated_[tuple_select] = on_false.CloneToUnique(); + evaluated_[tuple_select] = on_false.Clone(); } return Status::OK(); } @@ -1091,23 +1190,23 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloComputation* cond_comp = while_hlo->while_condition(); HloComputation* body_comp = while_hlo->while_body(); // Initialize the loop carried valued with the input to the While instruction. - auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique(); + auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone(); bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); HloEvaluator loop_body_evaluator(max_loop_iterations_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { - return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).", - while_hlo->name().c_str(), max_loop_iterations_); + return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", + while_hlo->name(), max_loop_iterations_); } - TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( - *cond_comp, {lcv.get()})); - keep_going = cond_val->GetFirstElement(); + TF_ASSIGN_OR_RETURN(auto cond_val, + cond_evaluator.Evaluate(*cond_comp, {&lcv})); + keep_going = cond_val.GetFirstElement(); if (keep_going) { TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( - *body_comp, {lcv.get()})); - VLOG(3) << "Loop iteration result: " << body_val->ToString(); + *body_comp, {&lcv})); + VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); cond_evaluator.ResetVisitStates(); loop_body_evaluator.ResetVisitStates(); @@ -1122,96 +1221,100 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { // hoops to make this work. namespace { template -StatusOr> EvaluateSortInternal( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSortInternal(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { auto rank = ShapeUtil::Rank(keys_literal.shape()); TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for rank-1 and rank-2 shapes, rank is: " - << rank; TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort"; - // We need to sort and array of keys and an array of values, where the + // We need to sort an array of keys and an array of values, where the // sorted order of the values is determined by the keys. The simplest(?) // way to do this is to go to an array-of-pairs representation, sort the // array using the keys, and then go back to pair-of-arrays. VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); - auto sort_r1 = [](const Literal& keys_literal, - const Literal& values_literal) { - const auto& keys_data = keys_literal.data(); - const auto& values_data = values_literal.data(); - - using kv_pair = std::pair; - std::vector key_value_vector; - CHECK_EQ(keys_data.size(), values_data.size()); - key_value_vector.reserve(keys_data.size()); - for (int i = 0; i < keys_data.size(); ++i) { - key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i])); - } - std::sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); - std::vector result_keys; - std::vector result_values; - for (const auto& key_value : key_value_vector) { - result_keys.push_back(key_value.first); - result_values.push_back(key_value.second); - } - auto result_keys_literal = MakeUnique(keys_literal.shape()); - result_keys_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_keys)); - auto result_values_literal = MakeUnique(values_literal.shape()); - result_values_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_values)); - return std::make_pair(std::move(result_keys_literal), - std::move(result_values_literal)); - }; - - std::unique_ptr result_tuple; - if (rank == 1) { - auto result_pair = sort_r1(keys_literal, values_literal); - result_tuple = LiteralUtil::MakeTuple( - {result_pair.first.get(), result_pair.second.get()}); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. - auto keys_result_literal = MakeUnique(keys_literal.shape()); - auto values_result_literal = MakeUnique(values_literal.shape()); - int64 r1_length = keys_literal.shape().dimensions(1); - for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto keys_r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - TF_ASSIGN_OR_RETURN(auto values_r1_slice, - values_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice); - TF_ASSIGN_OR_RETURN(auto sorted_keys, - r1_result_pair.first->Reshape({1, r1_length})); - TF_ASSIGN_OR_RETURN(auto sorted_values, - r1_result_pair.second->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom( - *sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); - TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom( - *sorted_values, {0, 0}, {row, 0}, {1, r1_length})); - } - result_tuple = LiteralUtil::MakeTuple( - {keys_result_literal.get(), values_result_literal.get()}); + if (rank == 0) { + // Nothing to sort. + return LiteralUtil::MakeTuple({&keys_literal, &values_literal}); } - VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + Literal keys_result_literal(keys_literal.shape()); + Literal values_result_literal(values_literal.shape()); + std::vector zero_base(rank, 0); + std::vector increment(rank, 1); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys_literal.shape(), zero_base, + AsInt64Slice(keys_literal.shape().dimensions()), increment, + [&](absl::Span indices) -> StatusOr { + // Extract a slice from the keys and values literals that correspond to + // exactly the row in dimension 'sort_dim'. + std::vector limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto keys_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& keys_data = keys_to_sort.data(); + TF_ASSIGN_OR_RETURN(auto values_to_sort, + values_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& values_data = values_to_sort.data(); + using kv_pair = std::pair; + std::vector key_value_vector; + key_value_vector.reserve(keys_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back( + std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); + std::vector result_keys; + std::vector result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + Literal sorted_keys(ShapeUtil::MakeShape( + keys_literal.shape().element_type(), {sort_dim_elements})); + sorted_keys.PopulateR1(absl::Span(result_keys)); + Literal sorted_values(ShapeUtil::MakeShape( + values_literal.shape().element_type(), {sort_dim_elements})); + sorted_values.PopulateR1(absl::Span(result_values)); + std::vector slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + std::vector start_indices(rank, 0); + TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped, + sorted_keys.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( + sorted_keys_reshaped, start_indices, indices, slice_dimensions)); + TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped, + sorted_values.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( + sorted_values_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + + Literal result_tuple; + result_tuple = + LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); return std::move(result_tuple); } template -StatusOr> EvaluateSortCurried( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSortCurried(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(1)->shape().element_type()) { case F32: return EvaluateSortInternal(sort, keys_literal, @@ -1230,9 +1333,9 @@ StatusOr> EvaluateSortCurried( } } -StatusOr> EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSort(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(0)->shape().element_type()) { case F32: return EvaluateSortCurried(sort, keys_literal, values_literal); @@ -1249,15 +1352,6 @@ StatusOr> EvaluateSort(HloInstruction* sort, } // namespace Status HloEvaluator::HandleSort(HloInstruction* sort) { - const int64 sort_dim = sort->dimensions(0); - const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); - if (sort_dim != rank - 1) { - return Unimplemented( - "Trying to support along dimension %lld, which is not the last " - "dimension", - sort_dim); - } - if (!ShapeUtil::IsTuple(sort->shape())) { return DefaultAction(sort); } else { @@ -1272,40 +1366,49 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { } } +Status HloEvaluator::HandleReduce(HloInstruction* reduce) { + if (!ShapeUtil::IsTuple(reduce->shape())) { + return DefaultAction(reduce); + } else { + auto first_element_type = reduce->shape().tuple_shapes(0).element_type(); + for (const auto& tuple_shape : reduce->shape().tuple_shapes()) { + if (tuple_shape.element_type() != first_element_type) { + return Unimplemented( + "Reduce with several outputs that have mixed element types is " + "unsupported"); + } + } + return reduce->Visit(typed_visitors_.at(first_element_type).get()); + } +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); - return Status::OK(); + return ShapeUtil::ValidateShape(hlo->shape()); } Status HloEvaluator::Postprocess(HloInstruction* hlo) { VLOG(2) << "Finished visiting " << hlo->ToString() << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); + // Out of convenience the literal may have been produced with a different + // layout. Relayout as indicated by the HLO instruction. + if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), + hlo->shape())) { + evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); + } return Status::OK(); } // Explicit instantiation of templatized Evaluate* methods. // -template StatusOr> -HloEvaluator::Evaluate(const HloModule& module, - ArraySlice arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - const HloModule& module, ArraySlice> arg_literals); - -template StatusOr> -HloEvaluator::Evaluate(const HloComputation& computation, - ArraySlice arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( +template StatusOr HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals); + +template StatusOr HloEvaluator::Evaluate( const HloComputation& computation, - ArraySlice> arg_literals); - -template StatusOr> -HloEvaluator::Evaluate(HloInstruction* instruction, - ArraySlice arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - HloInstruction* instruction, - ArraySlice> arg_literals); + absl::Span arg_literals); + +template StatusOr HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index a4c37ef32827892194da070ee05ec6dc4f4c306f..6c2662ebaeff5ff3ae21b19fac430c3490e22d36 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,7 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" @@ -47,12 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); + StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -70,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -83,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // 1. argument literals correspond to the input instruction's parameters in // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); + StatusOr Evaluate(HloInstruction* instruction, + absl::Span arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. // Precondition: // 1. all operands of the input instruction are constants. // 2. the instruction is not a Parameter operation. - StatusOr> Evaluate(HloInstruction* instruction); + StatusOr Evaluate(HloInstruction* instruction); - // Same as Evaluate, except returning nullptr on error. - std::unique_ptr TryEvaluate(HloInstruction* instruction); + // Same as Evaluate, except returning false on error and accepts an output + // pointer. + bool TryEvaluate(HloInstruction* instruction, Literal* result); // Evaluates a single HLO instruction, substituting the given literals for // some of the instruction's operands. // // For example, given instruction = op(A, B, C) and the map // {A = x, C = y}, this evaluates op(x, B, y). - StatusOr> EvaluateWithSubstitutions( + StatusOr EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions); - StatusOr> EvaluateElementwiseBinaryOp( - HloOpcode opcode, const Literal& lhs, const Literal& rhs); + StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs); - StatusOr> EvaluateElementwiseUnaryOp( - HloOpcode opcode, const Literal& operand); + StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, + const Literal& operand); - StatusOr> EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, - const Literal& rhs); + StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + const Literal& lhs, const Literal& rhs); protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this @@ -185,6 +184,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSort(HloInstruction* sort) override; + Status HandleReal(HloInstruction* real) override; + + Status HandleImag(HloInstruction* imag) override; + + Status HandleReduce(HloInstruction* reduce) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. @@ -196,7 +201,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); - return *(it->second); + return it->second; } // Tracks the HLO instruction and its evaluated literal result. @@ -204,12 +209,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // that are no longer a parent for any other subsequent instruction in // post-orderring. // Must be cleared for each evaluation. - tensorflow::gtl::FlatMap> - evaluated_; + // Storing Literal in place require the container to have pointer stability so + // we cannot use FlatMap any more. + std::unordered_map evaluated_; private: template - static StatusOr> ElementWiseUnaryOpImpl( + static StatusOr ElementWiseUnaryOpImpl( HloInstruction* instruction, const std::function& unary_op, const Literal& operand_literal) { @@ -222,13 +228,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); + ShapeUtil::HumanString(shape), + ShapeUtil::HumanString(operand->shape())); } - auto result = MakeUnique(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index cba72469ce73603f05d9957eb64e8519e8b8aec0..cee11a8a2166f96ae801095b6364921ed05d0000 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -51,12 +52,11 @@ static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, public HloVerifiedTestBase { protected: - HloEvaluatorTest() : use_bfloat16_(GetParam()) { - evaluator_ = MakeUnique(); + HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { + evaluator_ = absl::make_unique(); } - std::unique_ptr Evaluate( - tensorflow::gtl::ArraySlice arg_literals = {}) { + Literal Evaluate(absl::Span arg_literals = {}) { if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. auto type_converter = HloElementTypeConverter(F32, BF16); @@ -66,41 +66,53 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, .ConsumeValueOrDie(); } + // Evaluate function that takes in a local module instead of using module_ + // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is + // removed, this should be the default Evaluate function. + Literal EvaluateWithModule( + HloModule* module, absl::Span arg_literals = {}) { + if (use_bfloat16_) { + // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. + auto type_converter = HloElementTypeConverter(F32, BF16); + type_converter.Run(module).ValueOrDie(); + } + return evaluator_->Evaluate(*module->entry_computation(), arg_literals) + .ConsumeValueOrDie(); + } + std::unique_ptr evaluator_; - void TestUnaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr input, float aabs = 0) { + void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, + float aabs = 0) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); - b.AddInstruction( - HloInstruction::CreateUnary(expected->shape(), opcode, c1)); + b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - auto element_type = expected->shape().element_type(); + auto element_type = expected.shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error)); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } } - void TestBinaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr lhs, - std::unique_ptr rhs) { + void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs, + Literal rhs) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( - HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); + HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } bool use_bfloat16_; @@ -116,7 +128,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - Shape shape = low->shape(); + Shape shape = low.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -125,11 +137,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -137,7 +149,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0(1.f); - Shape shape = value->shape(); + Shape shape = value.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -146,11 +158,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -160,7 +172,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); - Shape shape = on_true->shape(); + Shape shape = on_true.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred))); auto c2 = @@ -171,11 +183,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) { HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -294,7 +306,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); - std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; + std::vector args = {&lhs, &rhs, &rhs2}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -312,11 +324,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { lhs_instruction, param_rhs2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(args); + Literal result = Evaluate(args); auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies Reshape operation is correctly evaluated. @@ -326,7 +338,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(literal))); @@ -336,14 +348,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - result->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT value) { - std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); - }); + result.EachCell([&](absl::Span indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + EXPECT_NEAR(value, literal_clone.Get(rindexes), 0.031250); + }); } // Verifies Broadcast operation is correctly evaluated. @@ -355,12 +366,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction* literal_instruction = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, {1, 2})); + output_literal.shape(), literal_instruction, {1, 2})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -373,13 +384,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloInstruction::CreateConstant(std::move(input_literal))); // Broadcast dimension should be empty in the case of scalars. b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, + output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -397,11 +408,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -419,10 +430,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({100, 200}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -431,17 +442,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto expected = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); - ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -451,17 +462,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); auto expected = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); - ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } PaddingConfig CreatePaddingConfig( @@ -494,12 +505,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { shape, operand_instruction, padding_value_instruction, padding_config)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -521,9 +532,9 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - auto expected_array = MakeUnique>(8, 5, 1, 1); + auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); (*expected_array)(1, 0, 0, 0) = 1.0f; (*expected_array)(1, 2, 0, 0) = 2.0f; @@ -534,7 +545,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = LiteralUtil::CreateR4FromArray4D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -547,7 +558,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = MakeUnique>(4, 3); + auto input_array = absl::make_unique>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = @@ -565,10 +576,10 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } - auto expected_array = MakeUnique>(1, 5); + auto expected_array = absl::make_unique>(1, 5); (*expected_array)(0, 0) = 7.0f; (*expected_array)(0, 1) = 2.718f; (*expected_array)(0, 2) = 2.718f; @@ -576,7 +587,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -588,7 +599,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = MakeUnique>(4, 3); + auto input_array = absl::make_unique>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = @@ -610,12 +621,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - auto expected_array = MakeUnique>(0, 9); + auto expected_array = absl::make_unique>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -628,7 +639,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // { 3 }, // { 4 }, // } - auto lhs_array = MakeUnique>(4, 1); + auto lhs_array = absl::make_unique>(4, 1); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = @@ -645,10 +656,11 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected_array = Array2D({ @@ -660,7 +672,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -679,7 +691,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique>(3, 2); + auto rhs_array = absl::make_unique>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = @@ -690,14 +702,15 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({22.f, 28.f}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -710,7 +723,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto lhs_array = MakeUnique>(4, 3); + auto lhs_array = absl::make_unique>(4, 3); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = @@ -722,7 +735,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique>(3, 2); + auto rhs_array = absl::make_unique>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = @@ -733,10 +746,11 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = Array2D({ {22.f, 28.f}, @@ -746,7 +760,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -784,17 +798,18 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { dnums.set_kernel_input_feature_dimension(1); dnums.add_kernel_spatial_dimensions(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -838,12 +853,13 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ -856,7 +872,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -921,22 +937,23 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); - Array4D expected_array_bf16({{{{2512, 2672}}}}); + Array4D expected_array_bf16({{{{2512, 2688}}}}); // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -998,22 +1015,23 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); - Array4D expected_array_bf16({{{{2512, 2672}}}}); + Array4D expected_array_bf16({{{{2512, 2688}}}}); // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1057,12 +1075,13 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -1076,7 +1095,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1120,12 +1139,13 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -1140,7 +1160,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, @@ -1191,12 +1211,13 @@ TEST_P(HloEvaluatorTest, ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1212,7 +1233,68 @@ TEST_P(HloEvaluatorTest, })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { + HloComputation::Builder b(TestName()); + std::vector input_dims = {1, 2, 2, 4}; + std::vector filter_dims = {2, 2, 2, 8}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + std::iota(input_elems.begin(), input_elems.end(), -7); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4))); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + std::iota(filter_elems.begin(), filter_elems.end(), -31); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4))); + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, + /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); + module().AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + Array4D expected_array(1, 1, 1, 8); + expected_array.FillWithYX( + Array2D({{668, 664, 660, 656, 668, 680, 692, 704}})); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; @@ -1245,9 +1327,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { module().AddEntryComputation(b.Build()); HloEvaluator hlo_eval; - std::unique_ptr result = - hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); - LiteralTestUtil::ExpectR0Equal(kNumElements, *result); + Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal(kNumElements, result); } // Reducing many numbers should be fast because it doesn't create @@ -1297,7 +1378,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1324,11 +1405,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({6, 18}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1339,7 +1420,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1376,10 +1457,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{6, 7}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1390,7 +1471,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1433,10 +1514,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1444,7 +1525,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. std::vector input_dims(6, 4); - std::unique_ptr arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = @@ -1494,12 +1575,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); std::vector output_dims = {4, 3, 3, 3, 4, 4}; - std::unique_ptr result_literal = + Literal result_literal = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 8.0f); - EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1511,7 +1592,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { // { 9, 10, 11, 12, 13 }, // { 17, 18, 19, 20, 21 }, // } - auto operand_array = MakeUnique>(3, 5); + auto operand_array = absl::make_unique>(3, 5); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1526,14 +1607,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*strides=*/{2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {3}, {19}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1544,7 +1625,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // { 1, 2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique>(2, 4); + auto operand_array = absl::make_unique>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1560,14 +1641,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1580,7 +1661,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { // { 1, 2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique>(2, 4); + auto operand_array = absl::make_unique>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1596,14 +1677,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { @@ -1614,7 +1695,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1633,14 +1714,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { shape, operand, update, start_indices)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, {5, -6, -7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1651,7 +1732,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); auto operand_literal2 = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1669,14 +1750,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, {5, 6, 7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1687,7 +1768,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( @@ -1708,16 +1789,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto result_inner_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); - auto expected = LiteralUtil::MakeTuple({ - result_inner_literal.get(), - result_inner_literal.get(), - }); + auto expected = + LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1748,7 +1827,7 @@ TEST_P(HloEvaluatorTest, Reverse) { b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected = LiteralUtil::CreateR4FromArray4D({ @@ -1770,7 +1849,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1786,12 +1865,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. HloEvaluator evaluator; + Literal param0_literal = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, LiteralUtil::CreateR1({1, 2, 3, 4}).get()}, - {square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + add, {{param0, ¶m0_literal}, {square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1811,11 +1891,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { // Evaluate add with square = {10, 20, 30, 40}. HloEvaluator evaluator; - auto result = evaluator.EvaluateWithSubstitutions( - add, {{square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); + auto result = + evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1826,21 +1907,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1851,21 +1931,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1876,22 +1955,21 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3( + LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), gather_indices.get()}))); + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1902,23 +1980,22 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), gather_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, @@ -1930,23 +2007,22 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), gather_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1957,21 +2033,19 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({1, 1}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), - *Evaluate({operand.get(), gather_indices.get()}))); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{5}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1982,21 +2056,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), gather_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{8}}, {{5}}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2007,20 +2080,18 @@ ENTRY main { operand = s32[3,0] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,0] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 0} + slice_sizes={1, 0} } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), - *Evaluate({operand.get(), gather_indices.get()}))); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2031,21 +2102,554 @@ ENTRY main { operand = s32[3] parameter(0) indices = s32[2,2,1] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1} + slice_sizes={1} } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr gather_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal start_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), - *Evaluate({operand.get(), gather_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{0, 1}, {2, 1}}), + Evaluate({&operand, &start_indices}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = + LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = + LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), + Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + Literal expected = + LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // + {{-40, 40}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + expected, Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, + EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + Literal expected = + LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + expected, Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + Literal expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + expected, Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); + Literal expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + expected, Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + operand, Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); + Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + Literal expected = LiteralUtil::CreateR1({10, 61, 32}); + EXPECT_TRUE(LiteralTestUtil::Equal( + expected, Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_NegativeIndices + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + // No updates should happen for the negative indices. + Literal scatter_indices = LiteralUtil::CreateR1({-1, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}}), + EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) { + const string hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + // No updates should happen for the OOB indices. + Literal scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + Literal updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}}), + EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd_OobUpdateWindow + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[1,2] parameter(1) + updates = s32[1,2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}}); + Literal updates = LiteralUtil::CreateR3({{{-10, 10}, {-40, 40}}}); + // Given the update window size of 2,2 and the index of 0,2, the update window + // will be OOB. So, nothing should be updated. + Literal expected = operand.Clone(); + EXPECT_TRUE(LiteralTestUtil::Equal( + expected, EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -2064,6 +2668,49 @@ TEST_P(HloEvaluatorTest, DoesCompareBF16) { std::move(rhs)); } +TEST_P(HloEvaluatorTest, Bf16Reduction) { + const string hlo_text = R"( +HloModule Bf16Reduction + +add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs) +} + +ENTRY main { + arg0 = bf16[4]{0} parameter(0) + init = bf16[] constant(0) + ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16 +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal arg = LiteralUtil::CreateR1( + {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); + Literal expected = LiteralUtil::CreateR0(bfloat16(44.0f)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); +} + +TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { + // Regression test for b/114735354. + const string hlo_text = R"( +HloModule SliceWithDifferentLayout + +ENTRY main { + arg = f32[2,2,2]{0,1,2} parameter(0) + ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal arg = LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + LayoutUtil::MakeLayout({0, 1, 2})); + Literal actual = Evaluate({&arg}); + EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index d1ee4a180be622523da13eb670a491fbd3dce23b..b2d12c94b848e4fd8ae473fdc0e4a9f5fecf6286 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,11 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -36,7 +43,9 @@ template using is_complex64_t = std::is_same; // It's UB to use std::sort with std::less, because of NaNs. Define -// "safe" less functions which are actually strict weak orders. +// "safe" less functions which are actually strict weak orders. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -44,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) { return a < b; } -template ::value || - std::is_same::value>::type* = nullptr> +template ::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (std::isnan(b)) { - return !std::isnan(a); - } else { - return a < b; + bool lhs_is_negative = std::signbit(a); + bool rhs_is_negative = std::signbit(b); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(a); + bool rhs_nan = std::isnan(b); + // Exactly one number is nan? + if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; } + return a < b; } -template ::value>::type* = nullptr> +template ::value || + std::is_same::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (Eigen::half_impl::isnan(b)) { - return !Eigen::half_impl::isnan(a); - } else { - return a < b; - } + return SafeLess(static_cast(a), static_cast(b)); } // Templated DfsHloVisitor for use by HloEvaluator. @@ -73,6 +89,8 @@ bool SafeLess(const NativeT& a, const NativeT& b) { // to this rule, notably: // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. +// - HandleImag and HandleReal: where the resulting literal type is always float +// and the operand is always complex, or real in the case of HandleReal. // These operations are handled outside of the parent HloEvaluator handlers // instead of from within TypedVisitor. // @@ -86,6 +104,29 @@ bool SafeLess(const NativeT& a, const NativeT& b) { // of this class. template class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { + private: + // Get the value in the given literal static_cast as a double. + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + double GetAsDouble(const Literal& literal, + absl::Span input_index) { + return static_cast(literal.Get(input_index)); + } + + // Specialization for complex types. In this case it is not possible to + // static_cast value to a double so just CHECK fail. This method is not used + // at run-time, but must be available at compile-time to keep the compiler + // happy. + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + double GetAsDouble(const Literal& literal, + absl::Span input_index) { + LOG(FATAL) << "Trying to get complex literal as double: " + << literal.ToString(); + } + public: explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {} @@ -117,7 +158,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } // TODO(b/35950897): many of the stl functions used in the handlers are not @@ -218,32 +259,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } Status HandleBitcastConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -301,14 +331,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleFloor(floor); } - Status HandleImag(HloInstruction* imag) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag], - ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) { - return std::imag(elem_operand); - })); - return Status::OK(); - } - Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { @@ -525,7 +547,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleDivide(HloInstruction* divide) override { + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleDivide(HloInstruction* divide) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { @@ -534,6 +560,46 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value && + std::is_integral::value>::type* = + nullptr> + Status HandleDivide(HloInstruction* divide) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[divide], + ElementWiseBinaryOp( + divide, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) -> ElementwiseT { + if (rhs_elem == 0) { + return static_cast(-1); + } + if (rhs_elem == -1 && + lhs_elem == std::numeric_limits::min()) { + return lhs_elem; + } + return lhs_elem / rhs_elem; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleDivide(HloInstruction* divide) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return rhs_elem == 0 + ? std::numeric_limits::max() + : (lhs_elem / rhs_elem); + })); + return Status::OK(); + } + + Status HandleDivide(HloInstruction* divide) { + return HandleDivide(divide); + } + template ::value>::type* = nullptr> @@ -612,26 +678,51 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleReal(HloInstruction* real) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[real], - ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) { - return std::real(elem_operand); + template ::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmod(lhs_el, rhs_el); })); return Status::OK(); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleRemainder(HloInstruction* remainder) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return std::fmod(lhs_el, rhs_el); + return rhs_el == 0 ? lhs_el : (lhs_el % rhs_el); })); return Status::OK(); } + template ::value && + std::is_integral::value>::type* = + nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[remainder], + ElementWiseBinaryOp( + remainder, + [](ElementwiseT lhs_el, ElementwiseT rhs_el) -> ElementwiseT { + if (rhs_el == 0) { + return lhs_el; + } + if (rhs_el == -1 && + lhs_el == std::numeric_limits::min()) { + return 0; + } + return lhs_el % rhs_el; + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -873,10 +964,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = MakeUnique(result_shape); + Literal result(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice out_index) { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span out_index) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -916,9 +1007,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums)); CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -941,10 +1033,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto lhs_literal_data = lhs_literal.data(); auto rhs_literal_data = rhs_literal.data(); + int64 feature_group_count = conv->feature_group_count(); + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data]( - tensorflow::gtl::ArraySlice out_index) { + rhs_literal_data, + feature_group_count](const absl::Span out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -955,7 +1049,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 output_batch_dim = dnums.output_batch_dimension(); const int64 output_z_dim = dnums.output_feature_dimension(); - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + const int64 input_z_size = + ShapeUtil::GetDimension(lhs_shape, input_z_dim); + // The size of an input feature group. + const int64 input_feature_group_size = input_z_size / feature_group_count; + + const int64 output_z_size = + ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); + // The output feature dimension is a concatenation of convolution results + // from the different groups. + const int64 output_feature_group_size = + output_z_size / feature_group_count; + + // Calculate the group index to which the current output index + // belongs. + const int64 feature_group_index = + out_index[output_z_dim] / output_feature_group_size; ElementwiseT result_val = static_cast(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -963,7 +1072,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { - for (int64 iz = 0; iz < z_size; ++iz) { + for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { + const int64 iz = + feature_group_index * input_feature_group_size + rhs_iz; + int64 lhs_linear_index = 0; lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; @@ -972,7 +1084,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 rhs_linear_index = 0; rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; // Find corresponding spatial dimension index for input (lhs). for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { @@ -1025,13 +1137,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { static_cast(rhs_literal_data[rhs_linear_index]); } cnt : {} - } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + } while (IndexUtil::BumpIndices(window_shape, + absl::MakeSpan(rhs_spatial_index))); return static_cast(result_val); }; - auto result = MakeUnique(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel(func)); + Literal result(result_shape); + TF_RETURN_IF_ERROR(result.PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); return Status::OK(); @@ -1078,7 +1191,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // result_index_locations[i] contains one or two pointers to the locations // in lhs_index or rhs_index where the i'th result index should go. - tensorflow::gtl::InlinedVector, kInlineRank> + absl::InlinedVector, kInlineRank> result_index_locations; result_index_locations.reserve(lhs_rank + rhs_rank - 2); @@ -1093,20 +1206,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { if (i != lhs_contracting_dimension && - !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) { + !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) { result_index_locations.push_back({&lhs_index[i], nullptr}); } } for (int64 i = 0; i < rhs_rank; i++) { if (i != rhs_contracting_dimension && - !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) { + !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) { result_index_locations.push_back({&rhs_index[i], nullptr}); } } - auto result = MakeUnique(dot->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice result_index) { + Literal result(dot->shape()); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span result_index) { ElementwiseT result_val = static_cast(0); for (int64 i = 0; i < result_index.size(); i++) { @@ -1153,24 +1266,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = MakeUnique(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&scalar](tensorflow::gtl::ArraySlice multi_index) { - return scalar; - })); + Literal result(pad->shape()); + TF_RETURN_IF_ERROR(result.Populate( + [&scalar](absl::Span multi_index) { return scalar; })); const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), 0); - std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + std::vector target_index(ShapeUtil::Rank(result.shape()), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. const PaddingConfig& pad_config = pad->padding_config(); - auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto func = [&](absl::Span input_index) { for (auto i = 0; i < input_index.size(); ++i) { // Interior padding occurs logically before edge padding, so in the case // of negative edge padding elements are removed from the @@ -1186,8 +1297,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return true; } } - result->Set(target_index, - evaluated_operand.Get(input_index)); + result.Set(target_index, + evaluated_operand.Get(input_index)); return true; }; @@ -1314,16 +1425,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> MapImpl(HloInstruction* map) { + StatusOr MapImpl(HloInstruction* map) { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = MakeUnique(map->shape()); + Literal result(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - std::vector> arg_literals; + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + std::vector arg_literals; arg_literals.reserve(operands.size()); // Construct scalar literal parameters to be passed to the map @@ -1338,16 +1449,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_literals.push_back(std::move(curr_val_literal)); } - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate>(*computation, - arg_literals) + Literal computed_result = + embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. embedded_evaluator.ResetVisitStates(); - return computed_result->Get({}); + return computed_result.Get({}); })); return std::move(result); } @@ -1415,48 +1524,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !std::is_same::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { auto keys = sort->operand(0); - auto rank = ShapeUtil::Rank(keys->shape()); - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for R1 and R2 shapes"; TF_RET_CHECK(sort->operand_count() == 1) << "Typed visitor does not support key-value sort"; const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - - auto sort_r1 = [this](const Literal& keys_literal) { - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - const auto& keys_data = keys_literal.data(); - - std::vector result_data(keys_data.begin(), keys_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess(a, b); - }); - auto result_literal = MakeUnique(keys_literal.shape()); - result_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); - return result_literal; - }; - - if (rank == 1) { - parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. - auto result_literal = MakeUnique(keys_literal.shape()); - int64 r1_length = keys->shape().dimensions(1); - for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result = sort_r1(*r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( - *r1_result, {0, 0}, {row, 0}, {1, r1_length})); - } - parent_->evaluated_[sort] = std::move(result_literal); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys->shape().dimensions(sort_dim); + int64 rank = ShapeUtil::Rank(keys->shape()); + if (rank == 0) { + // Nothing to sort. + parent_->evaluated_[sort] = keys_literal.Clone(); + return Status::OK(); } + Literal result_literal(keys_literal.shape()); + std::vector zero_base(rank, 0); + std::vector increment(rank, 1); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()), + increment, [&](absl::Span indices) -> StatusOr { + // Extract a slice from the literal that corresponds to exactly the + // row in dimension 'sort_dim'. + std::vector limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto row_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& row_data = row_to_sort.data(); + + std::vector result_data(row_data.begin(), row_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess(a, b); + }); + Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), + {sort_dim_elements})); + sorted_row.PopulateR1(absl::Span(result_data)); + std::vector slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped, + sorted_row.Reshape(slice_dimensions)); + std::vector start_indices(rank, 0); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + sorted_row_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + parent_->evaluated_[sort] = std::move(result_literal); return Status::OK(); } @@ -1472,16 +1588,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleSort(sort); } - Status HandleReduce(HloInstruction* reduce) override { - auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + Status HandleReduce(HloInstruction* hlo) override { + HloReduceInstruction* reduce = Cast(hlo); + int64 num_args = reduce->inputs().size(); + bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); - TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == - ShapeUtil::Rank(arg->shape()) - dimensions.size()); + + absl::InlinedVector operand_shapes; + for (const HloInstruction* operand : reduce->operands()) { + operand_shapes.push_back(&operand->shape()); + } TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReduceShape( - {&arg->shape(), &init_value->shape()}, + operand_shapes, /*dimensions_to_reduce=*/dimensions, /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) @@ -1489,14 +1609,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); - const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); - VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); + absl::InlinedVector arg_literals(num_args); + absl::InlinedVector init_literals(num_args); + for (int64 i = 0; i < num_args; ++i) { + arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]); + VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString(); + init_literals[i] = + &parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]); + VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape())); + } - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); + // All args and results have the same dimensions, so pick an arbitrary one. + const Shape& arg_shape = arg_literals[0]->shape(); + const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape()) + ? reduce->shape().tuple_shapes(0) + : reduce->shape(); + const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions()); std::vector arg_dim_steps(arg_dimensions.size()); std::vector arg_dim_counts(arg_dimensions.size()); for (const int64 dim : dimensions) { @@ -1514,61 +1643,107 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique(reduce->shape()); - // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - ReturnT result_val = init_scalar; + absl::InlinedVector results(num_args); + for (int64 i = 0; i < num_args; ++i) { + results[i] = Literal(result_shape); + } - std::vector base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } + Status eval_status; + // For each resulting dimension, calculate and assign computed values. + // This is really wasteful when num_args > 1, since we re-run the + // reduction num_args time. The alternative is to teach Populate() about + // tuples, which we should probably do. + absl::InlinedVector init_scalars(num_args); + for (int i = 0; i < num_args; ++i) { + init_scalars[i] = init_literals[i]->Get({}); + } - // When the reduction is addition of floats, accumulate in a double - // for better precision. Also, avoid creating Literals for the - // intermediate results; it's much faster. - if (ShapeUtil::ElementIsFloating(init_literal.shape()) && - IsScalarAdd(function)) { - double computed_result = 0; - auto func = [&](tensorflow::gtl::ArraySlice input_index) { - computed_result += arg_literal.Get(input_index); + for (int64 input = 0; input < num_args; ++input) { + TF_RETURN_IF_ERROR(results[input].Populate( + [&](absl::Span multi_index) { + if (!eval_status.ok()) { + return init_scalars[input]; + } + absl::InlinedVector result_values(init_scalars.begin(), + init_scalars.end()); + std::vector base(arg_dimensions.size()); + for (int64 i = 0; i < multi_index.size(); ++i) { + base[result_to_arg_index[i]] = multi_index[i]; + } + + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) && + IsScalarAdd(function)) { + CHECK_EQ(num_args, 1); + double computed_result = 0; + auto func = [&](absl::Span input_index) { + computed_result += + GetAsDouble(*arg_literals[0], input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base, + arg_dim_counts, arg_dim_steps, func); + return static_cast(computed_result); + } + auto func = + [&](absl::Span input_index) -> StatusOr { + absl::InlinedVector arg_values(num_args); + for (int64 i = 0; i < num_args; ++i) { + arg_values[i] = arg_literals[i]->Get(input_index); + } + + // Evaluate computation with specified literal operands. + absl::InlinedVector embedded_operands; + for (ReturnT value : result_values) { + embedded_operands.push_back( + LiteralUtil::CreateR0(value)); + } + for (ReturnT value : arg_values) { + embedded_operands.push_back( + LiteralUtil::CreateR0(value)); + } + absl::InlinedVector embedded_operands_ptrs( + embedded_operands.size()); + std::transform(embedded_operands.begin(), embedded_operands.end(), + embedded_operands_ptrs.begin(), + [](Literal& literal) { return &literal; }); + + TF_ASSIGN_OR_RETURN(Literal computed_result, + embedded_evaluator.Evaluate( + *function, embedded_operands_ptrs)); + // Clear visit states so that we can use the evaluator again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + // Assign computed result to result_val. + if (!has_tuple_output) { + result_values[0] = computed_result.Get({}); + } else { + for (int64 i = 0; i < num_args; ++i) { + result_values[i] = computed_result.Get( + /*multi_index=*/{}, /*shape_index=*/{i}); + } + } return true; }; - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return static_cast(computed_result); - } - auto func = [&](tensorflow::gtl::ArraySlice input_index) { - auto curr_val = arg_literal.Get(input_index); - - // Evaluate computation with specified literal operands. - auto curr_val_literal = LiteralUtil::CreateR0(curr_val); - auto result_val_literal = - LiteralUtil::CreateR0(result_val); - - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate( - *function, - {result_val_literal.get(), curr_val_literal.get()}) - .ConsumeValueOrDie(); - // Clear visit states so that we can use the evaluator again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. - result_val = computed_result->Get({}); - return true; - }; - // Computes one element of the result, reducing all dimensions that - // contribute to that element. - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return result_val; - })); - - parent_->evaluated_[reduce] = std::move(result); - return Status::OK(); + // Computes one element of the result, reducing all dimensions that + // contribute to that element. + eval_status = ShapeUtil::ForEachIndexWithStatus( + arg_shape, base, arg_dim_counts, arg_dim_steps, func); + return result_values[input]; + })); + } + if (!has_tuple_output) { + parent_->evaluated_[reduce] = std::move(results[0]); + } else { + Literal tuple_result(reduce->shape()); + for (int64 i = 0; i < num_args; ++i) { + TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); + } + parent_->evaluated_[reduce] = std::move(tuple_result); + } + return eval_status; } bool IsScalarAdd(HloComputation* computation) { @@ -1595,13 +1770,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = MakeUnique(select_and_scatter->shape()); + Literal result(select_and_scatter->shape()); // Initialize result array with the init value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { - return init_scalar; - })); + TF_RETURN_IF_ERROR(result.Populate( + [&](absl::Span output_index) { return init_scalar; })); std::vector window_dimension_sizes; for (const auto& window_dimension : window.dimensions()) { @@ -1639,8 +1812,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // 2. Using the selected index, scatter value from `source` to result. We // do this by iterating through the window, and compare each index with // the selected index. - tensorflow::gtl::optional selected_val; - tensorflow::gtl::optional> selected_index; + absl::optional selected_val; + absl::optional> selected_index; IterateThroughWindow( window_shape, window, operand_literal.shape(), source_index, @@ -1650,15 +1823,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val = curr_val; selected_index = operand_index; } - curr_val_literal->Set({}, curr_val); - selected_val_literal->Set({}, *selected_val); - std::unique_ptr computed_result = + curr_val_literal.Set({}, curr_val); + selected_val_literal.Set({}, *selected_val); + Literal computed_result = embedded_evaluator .Evaluate( - *select, - {selected_val_literal.get(), curr_val_literal.get()}) + *select, {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); - bool selected = !computed_result->Get({}); + bool selected = !computed_result.Get({}); if (selected) { selected_val = curr_val; selected_index = operand_index; @@ -1672,22 +1844,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (std::equal(operand_index.begin(), operand_index.end(), selected_index->begin())) { auto source = source_literal.Get(source_index); - auto scattered = result->Get(operand_index); - source_literal_scatter->Set({}, source); - scattered_literal->Set({}, scattered); - std::unique_ptr computed_result = + auto scattered = result.Get(operand_index); + source_literal_scatter.Set({}, source); + scattered_literal.Set({}, scattered); + Literal computed_result = embedded_evaluator - .Evaluate(*scatter, - {source_literal_scatter.get(), - scattered_literal.get()}) + .Evaluate( + *scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get({})); + result.Set(operand_index, computed_result.Get({})); // Clear visit states so that the we can use the evaluator again // on the same computation. embedded_evaluator.ResetVisitStates(); } }); - } while (IndexUtil::BumpIndices(source->shape(), &source_index)); + } while ( + IndexUtil::BumpIndices(source->shape(), absl::MakeSpan(source_index))); parent_->evaluated_[select_and_scatter] = std::move(result); return Status::OK(); @@ -1731,10 +1904,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique(reduce_window->shape()); + Literal result(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1750,18 +1923,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(curr_val); const auto result_val_literal = LiteralUtil::CreateR0(result_val); - std::unique_ptr computed_result = + Literal computed_result = embedded_evaluator .Evaluate( - *function, - {result_val_literal.get(), curr_val_literal.get()}) + *function, {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again // on the same computation. embedded_evaluator.ResetVisitStates(); - result_val = computed_result->Get({}); + result_val = computed_result.Get({}); }); return result_val; @@ -1771,6 +1943,383 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + // Reshapes the scatter indices input to have a trailing degenerate `1` + // dimension if necessary. Hands over the ownership of the newly created + // literal (if there is one) to `reshaped_indices`. + StatusOr> ReshapedScatterIndices( + int64 index_vector_dim, const Literal& indices, + Literal* reshaped_indices) { + if (indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(indices); + } + + std::vector new_shape(indices.shape().dimensions().begin(), + indices.shape().dimensions().end()); + new_shape.push_back(1); + TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); + return std::cref(*reshaped_indices); + } + + // Returns an ShapeUtil::IndexIterationSpace that iterates over the update + // scatter dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateScatterIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_scatter_dim = + !absl::c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_scatter_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // Return an ShapeUtil::IndexIterationSpace that iterates over the update + // window dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateWindowIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_window_dim = + absl::c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // This functor computes the contribution of scatter_indices to an input index + // corresponding to an update index. That is, given an update index I, it + // picks out the scatter indices in I and uses them to look up a scatter + // index, S, from the scatter indices tensor, and expands S into the input + // space according to scatter_dims_to_operand_dims. + // + // This is similar to the class HloEvaluator::OutputGatherIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateScatterIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateScatterIndexToInputIndex( + const ScatterDimensionNumbers* dim_numbers, const Shape& input_shape, + const Shape& updates_shape, const Literal* scatter_indices) + : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + update_dim_is_scatter_dims_.push_back( + !absl::c_binary_search(dim_numbers_.update_window_dims(), i)); + } + + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + int64 index_of_input_dim_in_index_vector = + FindIndex(dim_numbers_.scatter_dims_to_operand_dims(), i); + if (index_of_input_dim_in_index_vector == + dim_numbers_.scatter_dims_to_operand_dims_size()) { + input_dim_value_to_index_vector_.push_back(-1); + } else { + input_dim_value_to_index_vector_.push_back( + index_of_input_dim_in_index_vector); + } + } + + index_vector_index_.resize(scatter_indices_.shape().dimensions_size()); + input_index_.resize(input_shape.dimensions_size()); + int64 index_vector_size = + scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + index_vector_.resize(index_vector_size); + } + + // Returns the contribution of scatter_indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // scatter input index, but: + // + // - Instead of allocating memory to represent the scatter input index on + // every invocation we reuse the same storage for the result + // (input_index_), mutating it in place. + // - Instead of allocating buffers for temporary values like + // index_vector_index_ and index_vector on every invocation, we reuse the + // same storage for all invocations. + // + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span update_index) { + PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); + TF_RETURN_IF_ERROR(FetchIndexVector()); + PropagateIndexVectorToInputIndex(); + return absl::Span(input_index_); + } + + private: + // Propagates the scatter index dimensions from the update index into + // index_vector_index_ by mutating index_vector_index_ in place. Does not + // update the dim_numbers.index_vector_dim() dimension -- that's the + // dimension we iterate over in FetchIndexVector. + void PropagateUpdateIndexScatterDimsToIndexVectorIndex( + absl::Span update_index) { + int64 index_vector_index_i = 0; + for (int64 i = 0, e = update_index.size(); i < e; i++) { + if (!update_dim_is_scatter_dims_[i]) { + continue; + } + + if (index_vector_index_i == dim_numbers_.index_vector_dim()) { + index_vector_index_i++; + } + + index_vector_index_[index_vector_index_i++] = update_index[i]; + } + } + + // Populates index_vector_ by iterating over scatter_indices_ according to + // index_vector_index_. + Status FetchIndexVector() { + int64 index_vector_dim = dim_numbers_.index_vector_dim(); + for (int64 i = 0, e = index_vector_.size(); i < e; i++) { + index_vector_index_[index_vector_dim] = i; + TF_ASSIGN_OR_RETURN(index_vector_[i], scatter_indices_.GetIntegralAsS64( + index_vector_index_)); + } + return Status::OK(); + } + + // Populates input_index_. + void PropagateIndexVectorToInputIndex() { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_index_vector_[i] != -1) { + input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the index vector. See + // PropagateIndexVectorToInputIndex. + std::vector input_dim_value_to_index_vector_; + + // update_dim_is_scatter_dims_[i] is true iff the update index i is a + // scatter dimension. + std::vector update_dim_is_scatter_dims_; + + // The buffer into which we construct an index into scatter_indices_ to + // fetch the index vector. + std::vector index_vector_index_; + + // The index vector fetched from scatter_indices_. + std::vector index_vector_; + + // The result computed by this functor. operator() returns a Span + // into this vector. + std::vector input_index_; + + const ScatterDimensionNumbers& dim_numbers_; + const Literal& scatter_indices_; + }; + + // This functor computes the contribution of the window indices in an update + // index to an input index. That is, given an update index I it picks out the + // update window indices in I and expands it into a window index into the + // input shape. + // + // This is similar to the class HloEvaluator::OutputWindowIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateWindowIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateWindowIndexToInputIndex( + const ScatterDimensionNumbers& dim_numbers, const Shape& input_shape, + const Shape& updates_shape) { + std::vector window_index_to_update_index; + int64 update_index_count = 0; + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { + window_index_to_update_index.push_back(update_index_count++); + } else { + update_index_count++; + } + } + + int64 window_dim_count = 0; + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { + input_dim_value_to_update_index_.push_back(-1); + } else { + input_dim_value_to_update_index_.push_back( + window_index_to_update_index[window_dim_count++]); + } + } + + input_index_.resize(input_shape.dimensions_size()); + } + + // Returns the contribution of the window indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // window input index, but instead of allocating memory to represent the + // scatter input index on every invocation we reuse the same storage for the + // result (input_index_), mutating it in place. + // + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span update_index) { + PropagateUpdateIndexWindowDimsToInputIndex(update_index); + return absl::Span(input_index_); + } + + // Returns for a given 'input_dim' the corresponding update dimension index, + // or -1 if 'input_dim' is an elided window dimension. + int64 input_dim_value_to_update_index(int64 input_dim) { + return input_dim_value_to_update_index_[input_dim]; + } + + private: + // Propagates window dimensions from the update index to input_index_ by + // mutating input_index_ in place. + void PropagateUpdateIndexWindowDimsToInputIndex( + absl::Span update_index) { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_update_index_[i] != -1) { + input_index_[i] = update_index[input_dim_value_to_update_index_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the update index. See + // PropagateUpdateIndexWindowDimsToInputIndex. + std::vector input_dim_value_to_update_index_; + + // The result computed by this functor. operator() returns a Span + // into this vector. + std::vector input_index_; + }; + + Status HandleScatter(HloInstruction* scatter) override { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + const Literal& operand = + parent_->GetEvaluatedLiteralFor(scatter->operand(0)); + Literal reshaped_scatter_indices; + TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, + ReshapedScatterIndices(dim_numbers.index_vector_dim(), + parent_->GetEvaluatedLiteralFor( + scatter->operand(1)), + &reshaped_scatter_indices)); + const Literal& updates = + parent_->GetEvaluatedLiteralFor(scatter->operand(2)); + const Shape& updates_shape = updates.shape(); + const Shape& operand_shape = operand.shape(); + + ShapeUtil::IndexIterationSpace scatter_indices_iteration_space = + IterationSpaceForUpdateScatterIndices(updates_shape, dim_numbers); + ShapeUtil::IndexIterationSpace window_indices_iteration_space = + IterationSpaceForUpdateWindowIndices(updates_shape, dim_numbers); + + std::vector input_index(operand_shape.dimensions_size()); + std::vector update_index(updates_shape.dimensions_size()); + std::vector input_scatter_index_clamped( + operand_shape.dimensions_size()); + + UpdateScatterIndexToInputIndex update_scatter_index_to_input_index( + &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape, &scatter_indices); + UpdateWindowIndexToInputIndex update_window_index_to_input_index( + scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape); + + // Initialize the result with the operand. This makes it easier to handle + // the updates even when the indices are repeated. + Literal result = operand.Clone(); + HloEvaluator embedded_evaluator; + auto scatter_inner_loop_body = + [&](absl::Span update_window_index, + absl::Span input_scatter_index, + absl::Span update_scatter_index) -> StatusOr { + TF_ASSIGN_OR_RETURN( + absl::Span input_window_index, + update_window_index_to_input_index(update_window_index)); + for (int i = 0, e = update_index.size(); i < e; i++) { + update_index[i] = update_scatter_index[i] + update_window_index[i]; + DCHECK_LT(update_index[i], updates_shape.dimensions(i)); + } + for (int i = 0, e = input_scatter_index.size(); i < e; i++) { + int64 update_dim = + update_window_index_to_input_index.input_dim_value_to_update_index( + i); + // If 'update_dim' is -1, it means 'i' is an elided window dim. This + // means we set the iteration index to 0, so for the purpose of the + // following calculations we can consider the update dimension size to + // be 1. + int64 update_dim_size = + update_dim == -1 ? 1 : updates_shape.dimensions(update_dim); + // If any part of the update region is out-of-bounds, then do not + // perform any update on the input. + if ((input_scatter_index[i] < 0) || + (input_scatter_index[i] > + operand_shape.dimensions(i) - update_dim_size)) { + return true; + } + } + for (int i = 0, e = input_index.size(); i < e; i++) { + input_index[i] = input_scatter_index[i] + input_window_index[i]; + } + + auto result_value_literal = + LiteralUtil::CreateR0(result.Get(input_index)); + auto update_value_literal = + LiteralUtil::CreateR0(updates.Get(update_index)); + Literal updated_result = + embedded_evaluator + .Evaluate( + *scatter->to_apply(), + {&result_value_literal, &update_value_literal}) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on the + // same computation. + embedded_evaluator.ResetVisitStates(); + result.Set(input_index, updated_result.Get({})); + return true; + }; + + auto scatter_outer_loop_body = + [&](absl::Span update_scatter_index) -> StatusOr { + TF_ASSIGN_OR_RETURN( + absl::Span input_scatter_index, + update_scatter_index_to_input_index(update_scatter_index)); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, window_indices_iteration_space, + [&](absl::Span update_window_index) { + return scatter_inner_loop_body( + update_window_index, input_scatter_index, update_scatter_index); + })); + return true; + }; + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, scatter_indices_iteration_space, + scatter_outer_loop_body)); + parent_->evaluated_[scatter] = std::move(result); + return Status::OK(); + } + Status HandleSlice(HloInstruction* slice) override { auto operand = slice->operand(0); const Shape& shape = slice->shape(); @@ -1785,7 +2334,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 rank = ShapeUtil::Rank(operand->shape()); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](tensorflow::gtl::ArraySlice out_index) { + auto func = [&](absl::Span out_index) { DimensionVector operand_index(rank); for (int64 i = 0; i < rank; ++i) { operand_index[i] = @@ -1794,9 +2343,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get(operand_index); }; - auto result = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate(func)); + Literal result(shape); + TF_RETURN_IF_ERROR(result.Populate(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); } @@ -2001,11 +2549,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value || std::is_same::value>::type* = nullptr> - Status HandleIota(HloInstruction* iota) { - auto result = MakeUnique(iota->shape()); - auto data = result->data(); + Status HandleIota(HloInstruction* instruction) { + auto* iota = Cast(instruction); + std::vector data(iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); - parent_->evaluated_[iota] = std::move(result); + auto result = LiteralUtil::CreateR1(data); + + if (ShapeUtil::Rank(iota->shape()) > 1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[iota], + result.Broadcast(iota->shape(), {iota->iota_dimension()})); + } else { + TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + parent_->evaluated_[iota] = std::move(result); + } + return Status::OK(); } template & window_count_index, + const absl::Span& window_count_index, const std::function&)>& f) { const int64 rank = ShapeUtil::Rank(base_shape); DimensionVector window_index(rank); @@ -2065,13 +2623,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (!out_of_bound) { f(base_index); } - } while (IndexUtil::BumpIndices(window_shape, &window_index)); + } while ( + IndexUtil::BumpIndices(window_shape, absl::MakeSpan(window_index))); } template - StatusOr> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { + StatusOr DynamicSlice(const Literal& operand_literal, + const Literal& start_indices_literal, + const Shape& result_shape) { auto start_indices_typed = start_indices_literal.data(); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); @@ -2084,9 +2643,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector operand_indices(start.size()); - auto result = MakeUnique(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + Literal result(result_shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); operand_indices[i] = multi_index[i] + start[i]; @@ -2100,12 +2659,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); + StatusOr DynamicUpdateSlice(const Literal& operand_literal, + const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.Clone(); auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result->shape()); + const auto rank = ShapeUtil::Rank(result.shape()); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the @@ -2113,15 +2672,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { for (int64 i = 0; i < rank; ++i) { start[i] = std::min( std::max(0, start[i]), - result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + result.shape().dimensions(i) - update_literal.shape().dimensions(i)); } std::vector result_index(rank, 0); - auto func = [&](tensorflow::gtl::ArraySlice update_index) { + auto func = [&](absl::Span update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); - result->Set(result_index, - update_literal.Get(update_index)); + result.Set(result_index, + update_literal.Get(update_index)); return true; }; @@ -2134,7 +2693,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } - StatusOr> ElementWiseUnaryOp( + StatusOr ElementWiseUnaryOp( HloInstruction* instruction, const std::function& unary_op) { const Literal& operand_literal = @@ -2147,7 +2706,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr> ElementWiseBinaryOp( + StatusOr ElementWiseBinaryOp( HloInstruction* instruction, const std::function& binary_op) { @@ -2162,18 +2721,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = MakeUnique(shape); + Literal result(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2182,7 +2740,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> ElementwiseTernaryOp( + StatusOr ElementwiseTernaryOp( HloInstruction* instruction, const std::function& ternary_op) { const auto shape = instruction->shape(); @@ -2198,20 +2756,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape()), + ShapeUtil::HumanString(ehs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = MakeUnique(shape); + Literal result(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { return ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index c3ccbf0f0c75b569b49652807dea52faebdccc31..de3d7a167752f0de790585e50874dd6d2904bd37 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -49,7 +51,7 @@ std::unique_ptr CreateHloProfilePrinterData( size_t profile_counters_size = hlo_profile_index_map.total_count(); std::unique_ptr profile_printer_data = - MakeUnique(); + absl::make_unique(); profile_printer_data->set_profile_counters_size(profile_counters_size); profile_printer_data->mutable_computation_infos()->Reserve( hlo_profile_index_map.computation_count()); @@ -67,11 +69,11 @@ std::unique_ptr CreateHloProfilePrinterData( // The profile indices were computed deterministically in // HloProfileIndexMap::HloProfileIndexMap. - c_sort(computation_and_profile_idx_list, - [](const std::pair& left, - const std::pair& right) { - return left.second < right.second; - }); + absl::c_sort(computation_and_profile_idx_list, + [](const std::pair& left, + const std::pair& right) { + return left.second < right.second; + }); for (const auto& pair : computation_and_profile_idx_list) { CHECK_LT(pair.second, profile_counters_size); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index eba80c0f199f6224f4b46ac19af482c713585154..460ae2b5eca78659f86df1227e6a0a4e57508611 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -14,15 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::AllOf; using ::testing::ContainsRegex; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index bfe83cabf1c1168ee966827b7186004b708ad387..13a74fd8a115c5dc9a9518b226dfee4445cc7180 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -26,6 +26,12 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -37,50 +43,25 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" -using ::tensorflow::Env; -using ::tensorflow::WriteStringToFile; -using ::tensorflow::gtl::nullopt; -using ::tensorflow::gtl::optional; -using ::tensorflow::io::JoinPath; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::StringReplace; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { namespace hlo_graph_dumper { namespace { -// Helpers for Printf and Appendf. -template -struct PrintfConvert { - const T& operator()(const T& t) const { return t; } -}; -template <> -struct PrintfConvert { - const char* operator()(const string& s) const { return s.c_str(); } -}; - -// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str() -// on strings. -template -string Printf(const char* fmt, const Ts&... ts) { - return tensorflow::strings::Printf(fmt, PrintfConvert()(ts)...); -} -template -void Appendf(string* s, const char* fmt, const Ts&... ts) { - tensorflow::strings::Appendf(s, fmt, PrintfConvert()(ts)...); -} +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrFormat; +using absl::StrJoin; +using tensorflow::Env; +using tensorflow::WriteStringToFile; +using tensorflow::io::JoinPath; // Used to indicate how we should treat a given HLOInstruction in the graph. // should we treat it like normal, hide it, and so on? @@ -139,12 +120,23 @@ class NodeFilter { std::function filter_; }; +// We arbitrarily set this as the boundary between "large" and "small" +// instructions. +bool IsSmall(const HloInstruction* instr) { + if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) || + ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) { + return true; + } + return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096; +} + // Node color schemes, used by NodeColorAttributes. enum ColorScheme { kBlue, kBrown, kDarkBlue, kDarkGreen, + kDarkOrange, kDarkRed, kGray, kGreen, @@ -177,6 +169,10 @@ NodeColors NodeColorsForScheme(ColorScheme color) { return NodeColors{"filled", "#1565c0", "#003c8f", "white"}; case kDarkGreen: return NodeColors{"filled", "#2e7d32", "#005005", "white"}; + case kDarkOrange: + // This is more of a "medium" orange, made to look close to kOrange; + // there's probably room for a darker weight if desired. + return NodeColors{"filled", "#ffb74d", "#c88719", "black"}; case kDarkRed: return NodeColors{"filled", "#b71c1c", "#7f0000", "white"}; case kGray: @@ -209,17 +205,15 @@ NodeColors NodeColorsForScheme(ColorScheme color) { string NodeColorAttributes(ColorScheme color) { NodeColors node_colors = NodeColorsForScheme(color); - return Printf( - R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", - node_colors.style, node_colors.font_color, node_colors.stroke_color, - node_colors.fill_color); + return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, + node_colors.stroke_color, node_colors.fill_color); } // Replaces <> with <>, so that this string is safe(er) for use in a // graphviz HTML-like string. -string HtmlLikeStringSanitize(tensorflow::StringPiece s) { - return StringReplace(StringReplace(s, "<", "<", /*replace_all=*/true), ">", - ">", /*replace_all=*/true); +string HtmlLikeStringSanitize(absl::string_view s) { + return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}}); } // Tries to generates a human-readable one-word description of the given @@ -322,11 +316,11 @@ optional MatchTrivialComputation(const HloComputation* computation) { // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax). class HloDotDumper { public: - HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, + HloDotDumper(const HloComputation* computation, absl::string_view label, const DebugOptions& debug_options, bool show_backend_config, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), - label_(std::string(label)), + label_(label), debug_options_(debug_options), show_backend_config_(show_backend_config), profile_(profile), @@ -448,7 +442,7 @@ string HloDotDumper::Dump() { } string HloDotDumper::Header() { - const char* fmt = R"(digraph G { + constexpr char fmt[] = R"(digraph G { rankdir = TB; compound = true; label = <%s>; @@ -457,7 +451,7 @@ labelloc = t; tooltip = " "; // DOT graphs accept a stylesheet as a URI. So naturally, an inline // stylesheet is a data URI! -stylesheet=" +stylesheet=< data:text/css, @import url(https://fonts.googleapis.com/css?family=Roboto:400,700); svg text { @@ -466,7 +460,7 @@ stylesheet=" } %s -" +> )"; @@ -475,14 +469,13 @@ stylesheet=" string graph_label = StrCat(label_, "
Computation ", computation_->name()); if (computation_->IsFusionComputation()) { - StrAppend(&graph_label, - StrCat(" (in fusion instruction ", - computation_->FusionInstruction()->name(), ")")); + StrAppend(&graph_label, " (in fusion instruction ", + computation_->FusionInstruction()->name(), ")"); } if (profile_ != nullptr) { auto cycles = profile_->total_cycles_executed(*computation_); - Appendf(&graph_label, "
total cycles = %lld (%s)", cycles, - tensorflow::strings::HumanReadableNum(cycles)); + absl::StrAppendFormat(&graph_label, "
total cycles = %d (%s)", cycles, + tensorflow::strings::HumanReadableNum(cycles)); } // Create CSS rules that say, when you hover over the given node or cluster, @@ -509,14 +502,14 @@ stylesheet=" // One could imagine other ways of writing this CSS rule that involve // less duplication, but this way seems to be relatively performant. edge_css_rules.push_back( - Printf(" #%s%d:hover ~ #edge%lld text { fill: %s; }\n" - " #%s%d:hover ~ #edge%lld path { " - "stroke: %s; stroke-width: .2em; }\n" - " #%s%d:hover ~ #edge%lld polygon { " - "fill: %s; stroke: %s; stroke-width: .2em; }\n", - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, color)); + StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n" + " #%s%d:hover ~ #edge%d path { " + "stroke: %s; stroke-width: .2em; }\n" + " #%s%d:hover ~ #edge%d polygon { " + "fill: %s; stroke: %s; stroke-width: .2em; }\n", + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, color)); }; // The "to_node" value may be a NULL, indicating that this points to the @@ -559,10 +552,10 @@ stylesheet=" } } - return Printf(fmt, graph_label, Join(edge_css_rules, "\n")); + return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n")); } -string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); } +string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) { CHECK_EQ(instr->opcode(), HloOpcode::kFusion); @@ -600,9 +593,9 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() << " as " << next_edge_id_; edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = + constexpr char edge_fmt[] = R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( + edges_.push_back(StrFormat( edge_fmt, InstructionId(from), InstructionId(parent_instr), SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); } @@ -619,9 +612,10 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, string subcomp_label, style; if (parent_instr->opcode() == HloOpcode::kFusion) { - subcomp_label = Printf("Fused expression for %s
%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(parent_instr->ToCategory())); + subcomp_label = + StrFormat("Fused expression for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(parent_instr->ToCategory())); string extra_info = GetInstructionNodeExtraInfo(parent_instr); if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); @@ -647,18 +641,18 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; } style = - Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", - fillcolor, strokecolor); + StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", + fillcolor, strokecolor); } else { - subcomp_label = Printf("Subcomputation for %s
%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(subcomp->name())); + subcomp_label = StrFormat("Subcomputation for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(subcomp->name())); style = "style=rounded; color=black;"; } string comp_body = DumpComputation(subcomp); - const char* computation_fmt = R"(subgraph %s { + constexpr char computation_fmt[] = R"(subgraph %s { %s label = <%s>; labelloc = t; @@ -667,7 +661,7 @@ tooltip = " "; } // %s )"; - return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -718,11 +712,11 @@ string HloDotDumper::DumpRootTag() { VLOG(2) << "Adding edge from " << from->name() << " to root tag as " << next_edge_id_; edge_ids_.insert({{from, to}, next_edge_id_++}); - edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); + edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); - return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" - "\n", - to_id, node_body, node_shape, NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + "\n", + to_id, node_body, node_shape, NodeColorAttributes(color)); } static const HloConstantInstruction* TryGetFusionParameterConstant( @@ -817,10 +811,10 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" - "\n", - InstructionId(instr), node_body, node_shape, node_metadata, - NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" + "\n", + InstructionId(instr), node_body, node_shape, node_metadata, + NodeColorAttributes(color)); } string HloDotDumper::GetInstructionNodeInlinedOperands( @@ -833,7 +827,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. if (ShapeUtil::IsZeroElementArray(shape)) { - return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); + return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape())); } // Print the literal value of constants with <= K elements. @@ -848,19 +842,19 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // collected from profiling tools. Those constants may not have a valid // literal. if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { - return Printf("%s (%s)", constant->literal().ToString(), - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s (%s)", constant->literal().ToString(), + ShapeUtil::HumanString(constant->shape())); } // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; - if (tensorflow::str_util::StartsWith(constant->name(), "constant")) { + if (absl::StartsWith(constant->name(), "constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); } - return Printf("%s %s", constant_name, - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s %s", constant_name, + ShapeUtil::HumanString(constant->shape())); }; std::vector lines; @@ -881,7 +875,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( TryGetFusionParameterConstant(operand)) { operand_str = stringify_constant(constant); } else { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + operand_str = StrFormat("Parameter %d", operand->parameter_number()); } } else { operand_str = operand->name(); @@ -890,13 +884,13 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( if (operand_str) { if (instr->operand_count() > 1) { - lines.push_back(Printf("operand %lld = %s", i, *operand_str)); + lines.push_back(StrFormat("operand %d = %s", i, *operand_str)); } else { - lines.push_back(Printf("operand = %s", *operand_str)); + lines.push_back(StrFormat("operand = %s", *operand_str)); } } } - return Join(lines, "
"); + return StrJoin(lines, "
"); } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { @@ -913,7 +907,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { sharding_colors_.emplace(instr->sharding(), color); return color; } - const auto kParameterColor = kOrange; + + // Choose different weights of orange for small vs large parameters. This + // distinction is often important, especially in fusion nodes. + auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange; // Special case: If this instruction has a parameter merged into it, paint it // the same color as a parameter. Unless the merged-in parameter is a @@ -925,7 +922,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { ShouldMergeIntoUsers(operand) && TryGetFusionParameterConstant(operand) == nullptr; })) { - return kParameterColor; + return parameter_color; } // Pick different colors or shapes for instructions which are particularly @@ -1035,7 +1032,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReducePrecision: return kRed; case HloOpcode::kParameter: - return kParameterColor; + return parameter_color; case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: @@ -1048,6 +1045,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kMap: return kGray; case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: @@ -1058,7 +1057,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: - case HloOpcode::kHostCompute: case HloOpcode::kWhile: return kDarkGreen; case HloOpcode::kConstant: @@ -1079,14 +1077,13 @@ string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // If we have a parameter, put the param number in the name. if (instr->opcode() == HloOpcode::kParameter) { - return Printf("Parameter %lld", instr->parameter_number()); + return StrFormat("Parameter %d", instr->parameter_number()); } // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. - if (tensorflow::str_util::StartsWith(instr->name(), - HloOpcodeString(instr->opcode()))) { - return Printf("%s", HtmlLikeStringSanitize(instr->name())); + if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) { + return StrFormat("%s", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = StrCat(HloOpcodeString(instr->opcode()), @@ -1094,8 +1091,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { ? "" : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. - return Printf("%s
%s", HtmlLikeStringSanitize(extended_opcode), - HtmlLikeStringSanitize(instr->name())); + return StrFormat("%s
%s", HtmlLikeStringSanitize(extended_opcode), + HtmlLikeStringSanitize(instr->name())); } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { @@ -1104,16 +1101,16 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); } if (!instr->metadata().op_type().empty()) { - lines.push_back(Printf( + lines.push_back(StrFormat( "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type()))); } if (!instr->metadata().source_file().empty() && instr->metadata().source_line() != 0) { - lines.push_back(Printf("op_type: %s", instr->metadata().source_file(), - instr->metadata().source_line())); + lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(), + instr->metadata().source_line())); } - return Join(lines, "
"); + return StrJoin(lines, "\n"); } string HloDotDumper::GetInstructionNodeBackendConfig( @@ -1160,13 +1157,12 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { constexpr int kMaxShapeLen = 64; if (instr_shape.length() > kMaxShapeLen) { instr_shape = StrCat( - tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3), - "..."); + absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "..."); } lines.push_back(instr_shape); } if (debug_options_.xla_hlo_graph_addresses()) { - lines.push_back(Printf("[%p]", instr)); + lines.push_back(StrFormat("[%p]", instr)); } if (profile_ != nullptr) { double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr); @@ -1174,25 +1170,11 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { profile_->total_cycles_executed(*instr->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { lines.push_back( - Printf("%% of cycles executed=%.2f", - 100 * hlo_cycles_executed / total_cycles_executed)); + StrFormat("%% of cycles executed=%.2f", + 100 * hlo_cycles_executed / total_cycles_executed)); } } - return Join(lines, "
"); -} - -// Gets the total number of array elements in the given shape. For tuples, this -// is the sum of all the sizes of all of the array elements recursively in the -// tuple. -static int64 TotalElementsInShape(const Shape& shape) { - int64 elems = 0; - ShapeUtil::ForEachSubshape( - shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { - elems += ShapeUtil::ElementsIn(subshape); - } - }); - return elems; + return StrJoin(lines, "
"); } void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { @@ -1210,20 +1192,19 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { string edge_label; if (instr->operand_count() > 1 && !control_edge) { - edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num); + edge_label = + StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num); } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } // We print "small" arrays using a hollow arrowhead and "large" arrays using - // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" - // means. - bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; - - const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; - edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - (is_big_array ? "normal" : "empty"), from->name(), - to->name(), edge_label)); + // a filled arrowhead. + constexpr char kEdgeFmt[] = + R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; + edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to), + (IsSmall(from) ? "empty" : "normal"), + from->name(), to->name(), edge_label)); }; // Add edges from instr's operands to instr. Parameters within fusion @@ -1264,14 +1245,14 @@ string HloDotDumper::GetInstructionTrivialComputationStr( continue; } if (instr->called_computations().size() == 1) { - lines.push_back(Printf("Subcomputation: %s", - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation: %s", + HtmlLikeStringSanitize(*computation_type))); } else { - lines.push_back(Printf("Subcomputation %lld: %s", i, - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation %d: %s", i, + HtmlLikeStringSanitize(*computation_type))); } } - return Join(lines, "
"); + return StrJoin(lines, "
"); } const HloInstruction* HloDotDumper::GetNodeForEdge( diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1d7a062c55696de9db4b187efd86bce191279083..064c53252c0ac4d4e7b93169ad7cbee4807cb963 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,12 +24,11 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using ::tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::HasSubstr; string TestName() { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 7591b992048e373d4b67bb7863af4eb4b7f65e11..f7ec854d800fdbc0af3a53eb3bebe772f432e478 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -21,10 +21,17 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -39,17 +46,15 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; /* static */ StatusOr> HloInstruction::CreateFromProto( @@ -108,7 +113,7 @@ StatusOr> HloInstruction::CreateFromProto( std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), - tensorflow::gtl::ArraySlice(fft_length)); + absl::Span(fft_length)); break; } case HloOpcode::kSend: @@ -153,16 +158,26 @@ StatusOr> HloInstruction::CreateFromProto( CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); break; case HloOpcode::kReduce: - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Reduce instruction should have 2 operands but sees " + TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) + << "Reduce instruction should have an even number of operands but " + "sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Reduce instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = CreateReduce(proto.shape(), operands(0), operands(1), - std::vector(proto.dimensions().begin(), - proto.dimensions().end()), - computations(0)); + { + const auto reduce_operands = all_operands(); + auto inputs = absl::MakeSpan(reduce_operands) + .subspan(0, reduce_operands.size() / 2); + auto init_values = + absl::MakeSpan(reduce_operands) + .subspan(reduce_operands.size() / 2, reduce_operands.size()); + instruction = + CreateReduce(proto.shape(), inputs, init_values, + std::vector(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); + } break; case HloOpcode::kSort: { TF_RET_CHECK(proto.operand_ids_size() == 1 || @@ -224,7 +239,7 @@ StatusOr> HloInstruction::CreateFromProto( Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); } else { - instruction = MakeUnique(proto.shape()); + instruction = absl::make_unique(proto.shape()); } break; } @@ -235,7 +250,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); - instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + instruction = CreateTrace(literal.GetR1U8AsString(), operands(0)); break; } case HloOpcode::kFusion: { @@ -281,55 +296,66 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kInfeed: { const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); - if (proto.operand_ids_size() == 0) { - // TODO(b/80000000): Remove this when all uses of infeed are - // converted to take tokens. - instruction = CreateInfeed(data_shape, proto.infeed_config()); - } else { - CHECK_EQ(proto.operand_ids_size(), 1); - instruction = - CreateInfeed(data_shape, operands(0), proto.infeed_config()); - } + TF_RET_CHECK(proto.operand_ids_size() == 1); + instruction = + CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: - if (proto.operand_ids_size() == 1) { - // TODO(b/80000000): Remove this when all uses of outfeed are - // converted to take tokens. - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - proto.outfeed_config()); - } else { - CHECK_EQ(proto.operand_ids_size(), 2); - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - operands(1), proto.outfeed_config()); - } + TF_RET_CHECK(proto.operand_ids_size() == 2); + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + operands(1), proto.outfeed_config()); break; case HloOpcode::kCrossReplicaSum: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "CrossReplicaSum should have 1 called computation but sees " << proto.called_computation_ids_size(); - tensorflow::gtl::optional all_reduce_id; + absl::optional all_reduce_id; if (proto.all_reduce_id() > 0) { all_reduce_id = proto.all_reduce_id(); } instruction = CreateCrossReplicaSum( proto.shape(), all_operands(), computations(0), - /*replica_group_ids=*/ - std::vector(proto.replica_group_ids().begin(), - proto.replica_group_ids().end()), + /*replica_groups=*/ + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end()), /*barrier=*/proto.cross_replica_sum_barrier(), /*all_reduce_id=*/all_reduce_id); break; } - case HloOpcode::kConvolution: + case HloOpcode::kAllToAll: { + instruction = CreateAllToAll( + proto.shape(), all_operands(), + /*replica_groups=*/ + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end())); + break; + } + case HloOpcode::kCollectivePermute: { + std::vector> source_target_pairs( + proto.source_target_pairs_size()); + for (int i = 0; i < source_target_pairs.size(); i++) { + source_target_pairs[i].first = proto.source_target_pairs(i).source(); + source_target_pairs[i].second = proto.source_target_pairs(i).target(); + } + instruction = CreateCollectivePermute(proto.shape(), operands(0), + source_target_pairs); + break; + } + case HloOpcode::kConvolution: { TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); - instruction = - CreateConvolve(proto.shape(), operands(0), operands(1), - proto.window(), proto.convolution_dimension_numbers()); + PrecisionConfig precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfig::DEFAULT); + instruction = CreateConvolve( + proto.shape(), operands(0), operands(1), + std::max(proto.feature_group_count(), 1), proto.window(), + proto.convolution_dimension_numbers(), precision_config); break; + } case HloOpcode::kReduceWindow: TF_RET_CHECK(proto.operand_ids_size() == 2) << "ReduceWindow instruction should have 2 operands but sees " @@ -363,11 +389,9 @@ StatusOr> HloInstruction::CreateFromProto( ->set_convolution_dimension_numbers( proto.convolution_dimension_numbers()); } - break; - case HloOpcode::kHostCompute: - instruction = - CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(), - proto.cost_estimate_ns()); + static_cast(instruction.get()) + ->set_feature_group_count( + std::max(static_cast(proto.feature_group_count()), 1LL)); break; case HloOpcode::kPad: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -382,7 +406,7 @@ StatusOr> HloInstruction::CreateFromProto( << "DynamicSlice instruction should have 2 operands but sees " << proto.operand_ids_size(); std::vector slice_sizes(proto.dynamic_slice_sizes_size()); - c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), slice_sizes); break; @@ -394,14 +418,14 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_gather_dimension_numbers()) << "Gather instruction should have GatherDimensionNumbers set."; std::unique_ptr gather_dimension_numbers = - MakeUnique(proto.gather_dimension_numbers()); - std::vector gather_window_bounds; - for (int64 bound : proto.gather_window_bounds()) { - gather_window_bounds.push_back(bound); + absl::make_unique( + proto.gather_dimension_numbers()); + std::vector gather_slice_sizes; + for (int64 bound : proto.gather_slice_sizes()) { + gather_slice_sizes.push_back(bound); } - instruction = - CreateGather(proto.shape(), operands(0), operands(1), - *gather_dimension_numbers, gather_window_bounds); + instruction = CreateGather(proto.shape(), operands(0), operands(1), + *gather_dimension_numbers, gather_slice_sizes); break; } case HloOpcode::kScatter: { @@ -413,15 +437,44 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Scatter instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - auto scatter_dimension_numbers = MakeUnique( - proto.scatter_dimension_numbers()); + auto scatter_dimension_numbers = + absl::make_unique( + proto.scatter_dimension_numbers()); instruction = CreateScatter(proto.shape(), operands(0), operands(1), operands(2), computations(0), *scatter_dimension_numbers); break; } + case HloOpcode::kIota: + TF_RET_CHECK(proto.dimensions_size() <= 1) + << "Iota instruction should have at most 1 dimension but sees " + << proto.dimensions_size(); + instruction = CreateIota(proto.shape(), proto.dimensions(0)); + break; + case HloOpcode::kDot: { + TF_RET_CHECK(proto.has_dot_dimension_numbers()) + << "Dot instruction should have dot_dimension_numbers."; + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Dot instruction should have 2 operands but sees " + << proto.operand_ids_size(); + PrecisionConfig precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfig::DEFAULT); + instruction = absl::make_unique( + proto.shape(), operands(0), operands(1), + proto.dot_dimension_numbers(), precision_config); + break; + } + case HloOpcode::kDomain: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Domain instruction should have 1 operands but sees " + << proto.operand_ids_size(); + instruction = absl::make_unique( + proto.shape(), operands(0), /*operand_side_metadata=*/nullptr, + /*user_side_metadata=*/nullptr); + break; default: { - instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) << "No instruction with id " << operand_id; @@ -441,6 +494,9 @@ StatusOr> HloInstruction::CreateFromProto( computation_map.at(computation_id)); } } + TF_RET_CHECK(!proto.has_precision_config()) + << instruction->opcode() << proto.DebugString(); + TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode(); break; } } @@ -449,11 +505,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - - if (proto.has_dot_dimension_numbers()) { - instruction->dot_dimension_numbers_ = - MakeUnique(proto.dot_dimension_numbers()); - } + instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -466,44 +518,46 @@ StatusOr> HloInstruction::CreateFromProto( /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - return MakeUnique(parameter_number, shape, name); + return absl::make_unique(parameter_number, shape, + name); } /* static */ std::unique_ptr HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - return MakeUnique(tag, operand); + return absl::make_unique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( - std::unique_ptr literal) { - return MakeUnique(std::move(literal)); + Literal literal) { + return absl::make_unique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateIota( - const Shape& shape) { - return WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); + const Shape& shape, int64 iota_dimension) { + return absl::make_unique(shape, iota_dimension); } /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - return MakeUnique(shape, operand, index); + return absl::make_unique(shape, operand, + index); } /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters) { - return MakeUnique(shape, distribution, parameters); + absl::Span parameters) { + return absl::make_unique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { if (opcode == HloOpcode::kCopy) { // It is impossible to copy an opaque shape, we don't know how big it is. CHECK(!ShapeUtil::IsOpaque(shape)); } - auto instruction = WrapUnique(new HloInstruction(opcode, shape)); + auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -522,7 +576,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: - case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -554,7 +607,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kComplex: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kGe: case HloOpcode::kGt: @@ -600,54 +652,40 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateVariadic( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { CHECK_EQ(HloOpcode::kTuple, opcode); return CreateNary(shape, opcode, operands); } /* static */ std::unique_ptr HloInstruction::CreateMap( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* map_computation) { - return MakeUnique(shape, operands, map_computation); + return absl::make_unique(shape, operands, map_computation); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers) { - return MakeUnique(shape, lhs, rhs, window, - dimension_numbers); + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) { + return absl::make_unique( + shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config); } /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) { - return MakeUnique(shape, operand, fft_type, fft_length); + absl::Span fft_length) { + return absl::make_unique(shape, operand, fft_type, + fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - MakeUnique(dimension_numbers); - return instruction; -} - -/* static */ std::unique_ptr HloInstruction::CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = MakeUnique(); - instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); - instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); - return instruction; + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) { + return absl::make_unique( + shape, lhs, rhs, dimension_numbers, precision_config); } /* static */ std::unique_ptr @@ -655,52 +693,55 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - return MakeUnique( + return absl::make_unique( shape, operand, exponent_bits, mantissa_bits); } /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id) { - return MakeUnique( - shape, operands, reduce_computation, replica_group_ids, barrier, + const std::vector& replica_groups, absl::string_view barrier, + const absl::optional& all_reduce_id) { + return absl::make_unique( + shape, operands, reduce_computation, replica_groups, barrier, all_reduce_id); } -/* static */ std::unique_ptr HloInstruction::CreateInfeed( - const Shape& infeed_shape, HloInstruction* token_operand, - const string& config) { - return MakeUnique(infeed_shape, token_operand, config); +/* static */ std::unique_ptr HloInstruction::CreateAllToAll( + const Shape& shape, absl::Span operands, + const std::vector& replica_groups) { + return absl::make_unique(shape, operands, + replica_groups); } -/* static */ std::unique_ptr HloInstruction::CreateInfeed( - const Shape& infeed_shape, const string& config) { - return MakeUnique(infeed_shape, config); +/* static */ std::unique_ptr +HloInstruction::CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs) { + return absl::make_unique( + shape, operand, source_target_pairs); } -/* static */ std::unique_ptr HloInstruction::CreateOutfeed( - const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { - return MakeUnique(outfeed_shape, operand, - token_operand, outfeed_config); +/* static */ std::unique_ptr HloInstruction::CreateInfeed( + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config) { + return absl::make_unique(infeed_shape, token_operand, + config); } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - tensorflow::StringPiece outfeed_config) { - return MakeUnique(outfeed_shape, operand, - outfeed_config); + HloInstruction* token_operand, absl::string_view outfeed_config) { + return absl::make_unique( + outfeed_shape, operand, token_operand, outfeed_config); } /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique(operand, token, channel_id, - is_host_transfer); + return absl::make_unique(operand, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( @@ -708,14 +749,15 @@ HloInstruction::CreateCrossReplicaSum( auto send_operand = DynCast(operand); CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - return MakeUnique(send_operand, is_host_transfer); + return absl::make_unique(send_operand, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique(shape, token, channel_id, - is_host_transfer); + return absl::make_unique(shape, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( @@ -723,19 +765,20 @@ HloInstruction::CreateCrossReplicaSum( auto recv_operand = DynCast(operand); CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - return MakeUnique(recv_operand, is_host_transfer); + return absl::make_unique(recv_operand, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) { - return MakeUnique(shape, operand, dimensions); + absl::Span dimensions) { + return absl::make_unique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateAfterAll( - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { CHECK(!operands.empty()); - auto instruction = WrapUnique( + auto instruction = absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); for (auto operand : operands) { instruction->AppendOperand(operand); @@ -744,14 +787,15 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr HloInstruction::CreateToken() { - return WrapUnique( + return absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); } /* static */ std::unique_ptr HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); instruction->AppendOperand(init); // Body comes before condition computation in the vector. instruction->called_computations_.push_back(body); @@ -764,7 +808,7 @@ HloInstruction::CreateCrossReplicaSum( HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); instruction->AppendOperand(pred); instruction->AppendOperand(true_computation_arg); instruction->AppendOperand(false_computation_arg); @@ -778,18 +822,17 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr HloInstruction::CreateSlice( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { - return MakeUnique(shape, operand, start_indices, - limit_indices, strides); + absl::Span start_indices, + absl::Span limit_indices, absl::Span strides) { + return absl::make_unique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { - return MakeUnique(shape, operand, start_indices, - slice_sizes); + absl::Span slice_sizes) { + return absl::make_unique( + shape, operand, start_indices, slice_sizes); } /* static */ std::unique_ptr @@ -797,8 +840,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); + auto instruction = absl::WrapUnique( + new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); instruction->AppendOperand(operand); instruction->AppendOperand(update); instruction->AppendOperand(start_indices); @@ -806,14 +849,16 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension) { - return MakeUnique(shape, operands, dimension); + return absl::make_unique(shape, operands, + dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( const Shape& shape, HloInstruction* operand) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); instruction->AppendOperand(operand); return instruction; } @@ -822,38 +867,38 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction::CreateBitcastConvert(const Shape& shape, HloInstruction* operand) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); instruction->AppendOperand(operand); return instruction; } /* static */ std::unique_ptr HloInstruction::CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloReduceInstruction( + auto instruction = absl::WrapUnique(new HloReduceInstruction( shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); return std::move(instruction); } /* static */ std::unique_ptr HloInstruction::CreateReduce( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span operands, + absl::Span init_values, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) { std::vector all_args; all_args.reserve(operands.size() * 2); all_args.insert(all_args.end(), operands.begin(), operands.end()); all_args.insert(all_args.end(), init_values.begin(), init_values.end()); - return MakeUnique(shape, all_args, dimensions_to_reduce, - reduce_computation); + return absl::make_unique( + shape, all_args, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) { - return MakeUnique(shape, operand, init_value, - window, reduce_computation); + return absl::make_unique( + shape, operand, init_value, window, reduce_computation); } /* static */ std::unique_ptr @@ -862,7 +907,7 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) { - return MakeUnique( + return absl::make_unique( shape, operand, scale, offset, epsilon, feature_index); } @@ -871,7 +916,7 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - return MakeUnique( + return absl::make_unique( shape, operand, scale, offset, mean, variance, epsilon, feature_index); } @@ -881,9 +926,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - return MakeUnique(shape, operand, scale, mean, - variance, grad_output, epsilon, - feature_index); + return absl::make_unique( + shape, operand, scale, mean, variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -891,15 +936,15 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, const Window& window, HloInstruction* source, HloInstruction* init_value, HloComputation* scatter) { - return MakeUnique( + return absl::make_unique( shape, operand, select, window, source, init_value, scatter); } /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return MakeUnique(shape, operand, - broadcast_dimensions); + absl::Span broadcast_dimensions) { + return absl::make_unique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -957,8 +1002,8 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { - return MakeUnique(shape, operand, padding_value, - padding_config); + return absl::make_unique(shape, operand, padding_value, + padding_config); } /* static */ std::unique_ptr HloInstruction::CreateReshape( @@ -967,34 +1012,36 @@ HloInstruction::CreateBroadcastSequence( ShapeUtil::ElementsIn(operand->shape())) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(operand->shape()); - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); instruction->AppendOperand(operand); return instruction; } /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) { - return MakeUnique(shape, operand, dimensions); + absl::Span dimensions) { + return absl::make_unique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, HloInstruction* values) { - return MakeUnique(shape, dimension, keys, values); + return absl::make_unique(shape, dimension, keys, values); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - return MakeUnique(shape, fusion_kind, fused_root); + return absl::make_unique(shape, fusion_kind, + fused_root); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation) { - return MakeUnique(shape, fusion_kind, operands, - fusion_computation); + return absl::make_unique(shape, fusion_kind, operands, + fusion_computation); } void HloInstruction::set_single_sharding(const HloSharding& sharding) { @@ -1026,7 +1073,6 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: - case HloOpcode::kHostCompute: return true; case HloOpcode::kCrossReplicaSum: return all_reduce_id().has_value(); @@ -1049,10 +1095,10 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* computation) { std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -1061,21 +1107,14 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateCustomCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) { - return MakeUnique(shape, operands, - custom_call_target); -} - -/* static */ std::unique_ptr HloInstruction::CreateHostCompute( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { - return MakeUnique(shape, operands, channel_name, - cost_estimate_ns); + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target) { + return absl::make_unique(shape, operands, + custom_call_target); } /* static */ std::unique_ptr HloInstruction::CreateTuple( - tensorflow::gtl::ArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (auto element : elements) { element_shapes.push_back(element->shape()); @@ -1085,11 +1124,11 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateGather( - const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - return MakeUnique(shape, operand, gather_indices, - gather_dim_numbers, window_bounds); + absl::Span slice_sizes) { + return absl::make_unique( + shape, operand, start_indices, gather_dim_numbers, slice_sizes); } /* static */ std::unique_ptr HloInstruction::CreateScatter( @@ -1097,25 +1136,22 @@ bool HloInstruction::HasSideEffect() const { HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, const ScatterDimensionNumbers& scatter_dim_numbers) { - return MakeUnique(shape, operand, scatter_indices, - updates, update_computation, - scatter_dim_numbers); + return absl::make_unique( + shape, operand, scatter_indices, updates, update_computation, + scatter_dim_numbers); } /* static */ std::unique_ptr HloInstruction::CreateDomain( const Shape& shape, HloInstruction* operand, std::unique_ptr operand_side_metadata, std::unique_ptr user_side_metadata) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); - instruction->operand_side_metadata_ = std::move(operand_side_metadata); - instruction->user_side_metadata_ = std::move(user_side_metadata); - instruction->AppendOperand(operand); - return instruction; + return absl::make_unique( + shape, operand, std::move(operand_side_metadata), + std::move(user_side_metadata)); } std::unique_ptr HloInstruction::CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; @@ -1153,19 +1189,22 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: - case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kSort: case HloOpcode::kGather: case HloOpcode::kScatter: case HloOpcode::kIota: + case HloOpcode::kDot: + case HloOpcode::kDomain: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1238,11 +1277,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kDot: - CHECK_EQ(new_operands.size(), 2); - clone = CreateDot(shape, new_operands[0], new_operands[1], - *dot_dimension_numbers_); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); @@ -1267,12 +1301,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kDomain: - CHECK_EQ(new_operands.size(), 1); - clone = - CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), - user_side_metadata_->Clone()); - break; case HloOpcode::kAfterAll: if (new_operands.empty()) { clone = CreateToken(); @@ -1281,6 +1309,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } break; } + // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); clone->set_raw_backend_config_string(backend_config_); @@ -1346,7 +1375,7 @@ std::unique_ptr HloInstruction::Clone( // If names ends with .suffix[0-9]+ then replace with a suffix with the // numeric value incremented. int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { clone->name_ = StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1); } else { @@ -1464,7 +1493,7 @@ void HloInstruction::AppendOperand(HloInstruction* operand) { } void HloInstruction::RemoveOperandsAtAscendingIndices( - tensorflow::gtl::ArraySlice ascending_indices) { + absl::Span ascending_indices) { if (ascending_indices.empty()) { return; } @@ -1567,11 +1596,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAfterAll: return false; - // Check dot dimension numbers. - case HloOpcode::kDot: - return protobuf_util::ProtobufEquals(dot_dimension_numbers(), - other.dot_dimension_numbers()); - // Remaining instructions with special values. case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); @@ -1587,10 +1611,6 @@ bool HloInstruction::IdenticalSlowPath( return false; } - case HloOpcode::kDomain: - return operand_side_metadata().Matches(other.operand_side_metadata()) && - user_side_metadata().Matches(other.user_side_metadata()); - // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. case HloOpcode::kBatchNormTraining: @@ -1620,15 +1640,18 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: - case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kScatter: + case HloOpcode::kDot: + case HloOpcode::kDomain: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1818,7 +1841,7 @@ void HloInstruction::set_false_computation(HloComputation* false_computation) { string HloInstruction::SignatureString() const { string operands = - Join(operands_, ", ", [](string* out, HloInstruction* operand) { + StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) { StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); @@ -1838,7 +1861,7 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { } bool HloInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { switch (opcode_) { // Unary elementwise operations. case HloOpcode::kAbs: @@ -1959,13 +1982,13 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { string operands; - tensorflow::gtl::ArraySlice slice(operands_); + absl::Span slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; if (options.compact_operands() && slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) { // If operand is already been deleted, put `null` to the string output. if (operand == nullptr) { StrAppend(out, "null "); @@ -1985,7 +2008,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( } else if (!options.compact_operands()) { str.push_back(PrintName(operand->name(), options)); } - StrAppend(out, Join(str, " ")); + StrAppend(out, StrJoin(str, " ")); }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { @@ -1998,10 +2021,6 @@ std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { std::vector extra = ExtraAttributesToStringImpl(options); - if (dot_dimension_numbers_ != nullptr) { - extra.push_back(DotDimensionNumbersToString()); - } - if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { @@ -2027,11 +2046,11 @@ std::vector HloInstruction::ExtraAttributesToString( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { extra.push_back(StrCat( - "calls=", Join(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, - PrintName(computation->name(), options)); - }))); + "calls=", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, PrintName(computation->name(), options)); + }))); } } else if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kFullBodies) { @@ -2064,12 +2083,12 @@ std::vector HloInstruction::ExtraAttributesToString( break; default: if (!called_computations().empty()) { - extra.push_back( - StrCat("calls=\n", - Join(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, computation->ToString(new_options)); - }))); + extra.push_back(StrCat( + "calls=\n", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }))); } break; } @@ -2078,30 +2097,25 @@ std::vector HloInstruction::ExtraAttributesToString( if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } - if (!control_predecessors_.empty()) { + if (options.print_control_dependencies() && !control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", - Join(control_predecessors_, ", ", - [&](string* out, HloInstruction* pre) { - StrAppend(out, - PrintName(pre->name(), options)); - }), + StrJoin(control_predecessors_, ", ", + [&](string* out, HloInstruction* pre) { + StrAppend(out, + PrintName(pre->name(), options)); + }), "}")); } - if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { - extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), - "\", entry=", user_side_metadata_->ToString(), - ", exit=", operand_side_metadata_->ToString(), "}")); - } return extra; } string HloInstruction::ToShortString() const { return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(", - Join(operands_, ", ", - [](string* out, HloInstruction* operand) { - StrAppend(out, "%", operand->name()); - }), + StrJoin(operands_, ", ", + [](string* out, HloInstruction* operand) { + StrAppend(out, "%", operand->name()); + }), ")"); } @@ -2129,10 +2143,6 @@ HloInstructionProto HloInstruction::ToProto() const { } } - if (dot_dimension_numbers_ != nullptr) { - *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; - } - if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } @@ -2161,7 +2171,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } -bool HloInstruction::IsFusable() const { +bool HloInstruction::IsFusible() const { // Instructions which are traced should not be fused. if (tracing()) { return false; @@ -2265,6 +2275,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleFft(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kAllToAll: + return visitor->HandleAllToAll(this); + case HloOpcode::kCollectivePermute: + return visitor->HandleCollectivePermute(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2333,8 +2347,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleInfeed(this); case HloOpcode::kOutfeed: return visitor->HandleOutfeed(this); - case HloOpcode::kHostCompute: - return visitor->HandleHostCompute(this); case HloOpcode::kRng: return visitor->HandleRng(this); case HloOpcode::kWhile: @@ -2373,15 +2385,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return InternalError( "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " "please file a bug for XLA.", - HloOpcodeString(opcode_).c_str()); + HloOpcodeString(opcode_)); } // Explicit instantiations. template Status HloInstruction::Visit(DfsHloVisitor* visitor); template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor); -using DFSStack = - tensorflow::gtl::InlinedVector, 16>; +using DFSStack = absl::InlinedVector, 16>; // Push "child" onto the dfs_stack if not already visited. Returns false if a // cycle was detected, and true otherwise. @@ -2412,7 +2423,7 @@ template static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, const InternalCompareFunction* operand_order, bool ignore_control_predecessors) { - visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds()); + visitor->ReserveVisitStates(root->GetModule()->instruction_count()); // dfs_stack holds pairs of unique_id(), HloInstruction*>. // @@ -2457,7 +2468,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } @@ -2466,7 +2477,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } } @@ -2626,7 +2637,7 @@ bool HloInstruction::IsElementwiseBinary() const { } bool HloInstruction::IsElementwise() const { - return IsElementwiseImpl(tensorflow::gtl::nullopt); + return IsElementwiseImpl(absl::nullopt); } bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { @@ -2714,10 +2725,13 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { case HloOpcode::kTranspose: return UseKind::kUsePermutingElements; case HloOpcode::kPad: - case HloOpcode::kReduce: // Pad reuses the padding value but not the padded array elements. - // Reduce reuses the init value but not the operand array elements. return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements; + case HloOpcode::kReduce: + // Reduce reuses the init values but not the operand array elements. + return i >= Cast(this)->input_count() + ? UseKind::kReuse + : UseKind::kUsePermutingElements; case HloOpcode::kFusion: // Uses the memoizing, recursive computation defined above. return FusionReusesParamElements::Compute(i, *fused_expression_root()); @@ -2782,7 +2796,7 @@ StatusOr StringToFusionKind( if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } - return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); + return InvalidArgument("Unknown fusion kind: %s", kind_name); } string PaddingConfigToString(const PaddingConfig& padding) { @@ -2791,7 +2805,7 @@ string PaddingConfigToString(const PaddingConfig& padding) { [](const PaddingConfig::PaddingConfigDimension& dim) { return dim.interior_padding() != 0; }); - return Join( + return StrJoin( padding.dimensions(), "x", [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { StrAppend( @@ -2815,11 +2829,15 @@ string OpMetadataToString(const OpMetadata& metadata) { if (metadata.source_line() != 0) { result.push_back(StrCat("source_line=", metadata.source_line())); } - return Join(result, " "); + return StrJoin(result, " "); } string RandomDistributionToString(const RandomDistribution& distribution) { - return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); + return absl::AsciiStrToLower(RandomDistribution_Name(distribution)); +} + +string PrecisionToString(const PrecisionConfig::Precision& precision) { + return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision)); } string ConvolutionDimensionNumbersToString( @@ -2847,31 +2865,8 @@ string ConvolutionDimensionNumbersToString( output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", - Join(output_dims, "")); -} - -string HloInstruction::DotDimensionNumbersToString() const { - std::vector result; - if (dot_dimension_numbers_ == nullptr) { - return ""; - } - const DotDimensionNumbers& dnums = *dot_dimension_numbers_; - if (!dnums.lhs_batch_dimensions().empty()) { - result.push_back(StrCat("lhs_batch_dims={", - Join(dnums.lhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("lhs_contracting_dims={", - Join(dnums.lhs_contracting_dimensions(), ","), "}")); - - if (!dnums.rhs_batch_dimensions().empty()) { - result.push_back(StrCat("rhs_batch_dims={", - Join(dnums.rhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("rhs_contracting_dims={", - Join(dnums.rhs_contracting_dimensions(), ","), "}")); - - return Join(result, ", "); + return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->", + StrJoin(output_dims, "")); } StatusOr StringToRandomDistribution(const string& name) { @@ -2885,7 +2880,26 @@ StatusOr StringToRandomDistribution(const string& name) { } return map; }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); + auto found = map->find(absl::AsciiStrToLower(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + +StatusOr StringToPrecision(const string& name) { + static std::unordered_map* map = [] { + static auto* map = + new std::unordered_map; + for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) { + if (PrecisionConfig::Precision_IsValid(i)) { + auto value = static_cast(i); + (*map)[PrecisionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(absl::AsciiStrToLower(name)); if (found == map->end()) { return InvalidArgument("Unknown distribution"); } @@ -2896,6 +2910,26 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } +bool HloPtrComparator::operator()(const HloInstruction* const& lhs, + const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. + return false; + } + if (lhs == nullptr) { + return true; + } + auto lhs_module = lhs->GetModule(); + auto rhs_module = rhs->GetModule(); + CHECK((lhs_module == nullptr && rhs_module == nullptr) || + (lhs_module != nullptr && rhs_module != nullptr)); + if (lhs_module != nullptr && + lhs_module->unique_id() != rhs_module->unique_id()) { + return lhs_module->unique_id() < rhs_module->unique_id(); + } + return lhs->unique_id() < rhs->unique_id(); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -2932,6 +2966,16 @@ Status HloInstruction::set_backend_config( return ret; } +const PrecisionConfig& HloInstruction::precision_config() const { + if (auto* convolution = DynCast(this)) { + return convolution->precision_config(); + } + if (auto* dot = DynCast(this)) { + return dot->precision_config(); + } + LOG(FATAL) << "Unimplemented method."; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3135,8 +3179,13 @@ const string& HloInstruction::outfeed_config() const { return Cast(this)->outfeed_config(); } -const std::vector& HloInstruction::replica_group_ids() const { - return Cast(this)->replica_group_ids(); +const std::vector& HloInstruction::replica_groups() const { + return Cast(this)->replica_groups(); +} + +const std::vector>& +HloInstruction::source_target_pairs() const { + return Cast(this)->source_target_pairs(); } string HloInstruction::cross_replica_sum_barrier() const { @@ -3148,7 +3197,7 @@ void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { barrier); } -tensorflow::gtl::optional HloInstruction::all_reduce_id() const { +absl::optional HloInstruction::all_reduce_id() const { return Cast(this)->all_reduce_id(); } @@ -3174,6 +3223,18 @@ void HloInstruction::set_convolution_dimension_numbers( } } +int64 HloInstruction::feature_group_count() const { + if (auto convolution = DynCast(this)) { + return convolution->feature_group_count(); + } + return Cast(this)->feature_group_count(); +} + +void HloInstruction::set_feature_group_count(int64 feature_group_count) { + Cast(this)->set_feature_group_count( + feature_group_count); +} + HloComputation* HloInstruction::select() const { return Cast(this)->select(); } @@ -3194,10 +3255,6 @@ const string& HloInstruction::custom_call_target() const { return Cast(this)->custom_call_target(); } -const string& HloInstruction::channel_name() const { - return Cast(this)->channel_name(); -} - const PaddingConfig& HloInstruction::padding_config() const { return Cast(this)->padding_config(); } @@ -3214,9 +3271,8 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { return Cast(this)->gather_dimension_numbers(); } -tensorflow::gtl::ArraySlice HloInstruction::gather_window_bounds() - const { - return Cast(this)->gather_window_bounds(); +absl::Span HloInstruction::gather_slice_sizes() const { + return Cast(this)->gather_slice_sizes(); } const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() @@ -3224,4 +3280,15 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() return Cast(this)->scatter_dimension_numbers(); } +const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const { + return Cast(this)->dot_dimension_numbers(); +} + +const DomainMetadata& HloInstruction::operand_side_metadata() const { + return Cast(this)->operand_side_metadata(); +} + +const DomainMetadata& HloInstruction::user_side_metadata() const { + return Cast(this)->user_side_metadata(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e722086732947c41c9b1bfa76fe88fe35c3e45d6..d615df0831f6306b75e099ee10353a199878b42b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -32,6 +32,11 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -45,10 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -80,6 +82,7 @@ class HloPrintOptions { print_operand_shape_(true), print_program_shape_(true), print_percent_(true), + print_control_dependencies_(true), canonicalize_instruction_names_(false), indent_amount_(0), is_in_nested_computation_(false) {} @@ -92,7 +95,8 @@ class HloPrintOptions { .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) - .set_print_percent(false); + .set_print_percent(false) + .set_print_control_dependencies(false); } // Options to produce the canonical string representing an isomorphic @@ -101,10 +105,12 @@ class HloPrintOptions { return HloPrintOptions() .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) .set_print_metadata(false) + .set_print_backend_config(false) .set_compact_operands(true) .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) + .set_print_control_dependencies(false) .set_canonicalize_instruction_names(true); } @@ -150,6 +156,12 @@ class HloPrintOptions { return *this; } + // If true, control dependencies will be printed. + HloPrintOptions& set_print_control_dependencies(bool value) { + print_control_dependencies_ = value; + return *this; + } + // If true, only a part of operands will be printed out, and their names will // be omitted (note that in this case the text will not be parsable). HloPrintOptions& set_compact_operands(bool value) { @@ -182,11 +194,14 @@ class HloPrintOptions { return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } - bool print_backend_config() const { return print_metadata_; } + bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool print_control_dependencies() const { + return print_control_dependencies_; + } bool canonicalize_instruction_names() const { return canonicalize_instruction_names_; } @@ -202,6 +217,7 @@ class HloPrintOptions { bool print_operand_shape_; bool print_program_shape_; bool print_percent_; + bool print_control_dependencies_; bool canonicalize_instruction_names_; int indent_amount_; bool is_in_nested_computation_; @@ -220,7 +236,7 @@ class CanonicalNameMap { return iter->second; } - string new_name = tensorflow::strings::StrCat("tmp_", index++); + string new_name = absl::StrCat("tmp_", index++); canonical_name_map[old_name] = new_name; return new_name; } @@ -343,11 +359,11 @@ class HloInstruction { const string& name); // Creates a literal constant instruction. - static std::unique_ptr CreateConstant( - std::unique_ptr literal); + static std::unique_ptr CreateConstant(Literal literal); // Creates an Iota instruction. - static std::unique_ptr CreateIota(const Shape& shape); + static std::unique_ptr CreateIota(const Shape& shape, + int64 iota_dimension); // Creates a get tuple element instruction. static std::unique_ptr CreateGetTupleElement( @@ -361,7 +377,7 @@ class HloInstruction { // random numbers from a given distribution. static std::unique_ptr CreateRng( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters); + absl::Span parameters); // Creates a unary instruction (one operand). // Precondition: opcode must be a legitimate unary operation. @@ -388,38 +404,34 @@ class HloInstruction { // Precondition: opcode must be a legitimate variadic operation. static std::unique_ptr CreateVariadic( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Creates a map instruction, where the computation (given by the handle) is // applied element-wise to every element in operands (across the operands, // at a given index) static std::unique_ptr CreateMap( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* map_computation); // Creates a convolution op, where rhs is the convolutional filter // and window describes how the filter is applied to lhs. static std::unique_ptr CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch // dimensions specified in 'dimension_numbers'. static std::unique_ptr CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers); - - // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 - // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS - // and the RHS must be of rank 2. - static std::unique_ptr CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to @@ -432,9 +444,10 @@ class HloInstruction { // // `reduction_computation`: the reduction function. // - // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all - // replicas belong to one group. Allreduce will be applied within subgroups. - // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group in the order of 0 - (n-1). + // Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // // `all_reduce_id`: for Allreduce nodes from different modules, if they have @@ -443,11 +456,36 @@ class HloInstruction { // // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id); + const std::vector& replica_groups, + absl::string_view barrier, const absl::optional& all_reduce_id); + + // This op handles the communication of an Alltoall operation. On each core, + // the operands are N ops in the same shape, where N is the number of cores + // participating the Alltoall. Then the N operands are scattered to N cores, + // e.g., the ith operand is sent to the ith core. Then each core gathers the + // received data into a tuple. + // + // - `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall + // will be applied within subgroups in the specified order. For example, + // replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied + // within replica 1, 2, 3, and in the gather phase, the received blocks will + // be concatenated in the order of 1, 2, 3; another Alltoall will be applied + // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. + static std::unique_ptr CreateAllToAll( + const Shape& shape, absl::Span operands, + const std::vector& replica_groups); + + // Creates a communitation instructions that permutes data cross replicas. + // Data is sent/received according to the (source_replica_id, + // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a + // target_replica_id in any pair, the output on that replica is a tensor + // conssits of 0(s) in `shape`. + static std::unique_ptr CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -466,24 +504,13 @@ class HloInstruction { static std::unique_ptr CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config); - // Overload which does not require a token. - // TODO(b/80000000): Remove this overload when all uses of infeed are - // converted to take tokens. - static std::unique_ptr CreateInfeed(const Shape& infeed_shape, - const string& config); // Creates an outfeed instruction, which outputs data. outfeed_shape is the // shape of the data being outfed *not* the shape of the outfeed instruction // which is a TOKEN. static std::unique_ptr CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); - // Overload which does not require a token. - // TODO(b/80000000): Remove this overload when all uses of outfeed are - // converted to take tokens. - static std::unique_ptr CreateOutfeed( - const Shape& outfeed_shape, HloInstruction* operand, - tensorflow::StringPiece outfeed_config); + HloInstruction* token_operand, absl::string_view outfeed_config); // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in @@ -516,17 +543,15 @@ class HloInstruction { // start/limit indices. static std::unique_ptr CreateSlice( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, absl::Span strides); // Creates a slice instruction, where the first operand is sliced by // start indices specified in the second operand, and by size specified in // 'slice_sizes'. static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + HloInstruction* start_indices, absl::Span slice_sizes); // Creates a dynamic update slice instruction, which updates a slice // of 'operand' with 'update' and 'start_indices'. @@ -537,7 +562,7 @@ class HloInstruction { // Creates a concatenate instruction, where the operands are concatenated on // the provided dimension. static std::unique_ptr CreateConcatenate( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension); // Creates a reduce instruction, where the computation (given by the handle) @@ -549,7 +574,7 @@ class HloInstruction { // f(f(init, value0), value1), ...) static std::unique_ptr CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation); // A more general, multiple-argument version of the above. @@ -564,9 +589,9 @@ class HloInstruction { // ... // TODO(b/112040122): Add support to this in HLO passes and in backends. static std::unique_ptr CreateReduce( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span operands, + absl::Span init_values, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation); // Creates a reduce-window instruction, where the computation (given @@ -603,7 +628,7 @@ class HloInstruction { // Creates a broadcast instruction. static std::unique_ptr CreateBroadcast( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Creates a sequence of instructions that performs an explicit broadcast of // the operand to the target shape. @@ -633,7 +658,7 @@ class HloInstruction { // Creates a transpose instruction which permutes the operand dimensions. static std::unique_ptr CreateTranspose( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a sort op, with a keys operand, and an optional values operand. static std::unique_ptr CreateSort( @@ -657,9 +682,9 @@ class HloInstruction { static std::unique_ptr CreateGather( const Shape& shape, HloInstruction* operand, - HloInstruction* gather_indices, + HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); static std::unique_ptr CreateScatter( const Shape& shape, HloInstruction* operand, @@ -683,43 +708,37 @@ class HloInstruction { static std::unique_ptr CreateFusion( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation); // Creates a call instruction that applies the given computation on the given // operands. "shape" is the resultant shape. static std::unique_ptr CreateCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* computation); // Creates a custom call instruction that applies the given custom call target // to the given operands. "shape" is the resultant shape. static std::unique_ptr CreateCustomCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target); - - // Creates a HostCompute instruction, which records host-side control and - // data dependencies for use in instruction scheduling. - static std::unique_ptr CreateHostCompute( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. static std::unique_ptr CreateTuple( - tensorflow::gtl::ArraySlice elements); + absl::Span elements); // Creates a reverse instruction, which reverses the order of the elements // in the specified dimensions. static std::unique_ptr CreateReverse( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a Afterall instruction used for joining or creating new values of // token type which thread through side-effecting operations. Operands must // all be tokens, and there must be at least one operand. static std::unique_ptr CreateAfterAll( - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Creates an AfterAll instruction which creates a token type out of thin air // (no operands). This is a separate method from CreateAfterAll to facility @@ -756,7 +775,7 @@ class HloInstruction { int64 operand_count() const { return operands_.size(); } // Returns the vector of operands of this instruction. - using InstructionVector = tensorflow::gtl::InlinedVector; + using InstructionVector = absl::InlinedVector; const InstructionVector& operands() const { return operands_; } // Returns the vector of unique operands, in the same order they are found @@ -1020,7 +1039,7 @@ class HloInstruction { // Returns true if this instruction can be legally fused into a fusion // instruction. - bool IsFusable() const; + bool IsFusible() const; // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. @@ -1028,21 +1047,26 @@ class HloInstruction { CHECK(has_sharding()); return *sharding_; } + std::shared_ptr sharding_ptr() const { return sharding_; } + // Returns the sharding applied to this operator, or default_ if none exists. const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; } // Returns the sharding unique device, if any. - tensorflow::gtl::optional sharding_unique_device() const { + absl::optional sharding_unique_device() const { if (sharding_ == nullptr) { - return tensorflow::gtl::optional(); + return absl::optional(); } return sharding_->UniqueDevice(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { - sharding_ = MakeUnique(sharding); + sharding_ = std::make_shared(sharding); + } + void set_sharding(std::shared_ptr sharding) { + sharding_ = std::move(sharding); } void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. @@ -1062,15 +1086,6 @@ class HloInstruction { return other->has_sharding() ? sharding() == other->sharding() : false; } - // Retrieves the operand side metadata of a kDomain instruction. - const DomainMetadata& operand_side_metadata() const { - return *operand_side_metadata_; - } - // Retrieves the user side metadata of a kDomain instruction. - const DomainMetadata& user_side_metadata() const { - return *user_side_metadata_; - } - // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain // properties of the new instruction are copied into the derived one. As of @@ -1078,28 +1093,6 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // TODO(b/80249101): Remove these methods once HLO scheduling and copy - // insertion are integrated, and we don't need to run a separate pass - // of copy elision anymore. - bool CopyElisionAllowed() const { - CHECK_EQ(HloOpcode::kCopy, opcode_); - return copy_elision_allowed_; - } - - void SetCopyElisionAllowed(bool value) { - CHECK_EQ(HloOpcode::kCopy, opcode_); - copy_elision_allowed_ = value; - } - - // Returns data on the dimension numbers used for a dot operation. - const DotDimensionNumbers& dot_dimension_numbers() const { - CHECK(dot_dimension_numbers_ != nullptr); - return *dot_dimension_numbers_; - } - - // Returns the dump string of the dot dimension numbers. - string DotDimensionNumbersToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1110,8 +1103,7 @@ class HloInstruction { // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). @@ -1243,6 +1235,16 @@ class HloInstruction { static StatusOr BackendConfigToRawString( const tensorflow::protobuf::Message& proto); + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + // Precondition: opcode must be kConvolution or kDot. + const PrecisionConfig& precision_config() const; + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1411,15 +1413,18 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; - // Delegates to HloAllReduceInstruction::replica_group_ids. - const std::vector& replica_group_ids() const; + // Delegates to HloCollectiveInstruction::replica_groups. + const std::vector& replica_groups() const; + + // Delegates to HloCollectivePermuteInstruction::source_target_pairs. + const std::vector>& source_target_pairs() const; // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); // Delegates to HloAllReduceInstruction::all_reduce_id. - tensorflow::gtl::optional all_reduce_id() const; + absl::optional all_reduce_id() const; // Returns data on the window in a windowed operation such as // convolution. @@ -1443,6 +1448,12 @@ class HloInstruction { void set_convolution_dimension_numbers( const ConvolutionDimensionNumbers& dnums); + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count() const; + + void set_feature_group_count(int64 feature_group_count); + // Delegates to HloSelectAndScatterInstruction::select. HloComputation* select() const; @@ -1458,9 +1469,6 @@ class HloInstruction { // Delegates to HloCustomCallInstruction::custom_call_target. const string& custom_call_target() const; - // Delegates to HloHostComputeInstruction::channel_name. - const string& channel_name() const; - // Delegates to HloPadInstruction::padding_config. const PaddingConfig& padding_config() const; @@ -1472,12 +1480,21 @@ class HloInstruction { // Delegates to HloGatherInstruction::gather_dimension_numbers. const GatherDimensionNumbers& gather_dimension_numbers() const; - // Delegates to HloGatherInstruction::gather_window_bounds. - tensorflow::gtl::ArraySlice gather_window_bounds() const; + // Delegates to HloGatherInstruction::gather_slice_sizes. + absl::Span gather_slice_sizes() const; // Delegates to HloScatterInstruction::scatter_dimension_numbers(). const ScatterDimensionNumbers& scatter_dimension_numbers() const; + // Delegates to HloDotInstruction::dot_dimension_numbers(). + const DotDimensionNumbers& dot_dimension_numbers() const; + + // Delegates to HloDomainInstruction::operand_side_metadata(). + const DomainMetadata& operand_side_metadata() const; + + // Delegates to HloDomainInstruction::user_side_metadata(). + const DomainMetadata& user_side_metadata() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1499,7 +1516,7 @@ class HloInstruction { // Removes a list of operands with the given indices in ascending order. void RemoveOperandsAtAscendingIndices( - tensorflow::gtl::ArraySlice ascending_indices); + absl::Span ascending_indices); void AppendComputation(HloComputation* computation) { called_computations_.push_back(computation); @@ -1529,8 +1546,7 @@ class HloInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. virtual std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { // TODO(b/80131774): This should be pure virtual. LOG(FATAL) << "Unimplemented method."; @@ -1548,7 +1564,7 @@ class HloInstruction { // NOTE: For all instructions other than kFusion, being elementwise on one of // the operands is equivalent to being elementwise on all the operands. virtual bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const; + const absl::optional& operand_idx) const; // Prints an instruction to a string. // // The canonical string representation needs to name operands and instruction @@ -1576,7 +1592,7 @@ class HloInstruction { // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Adds a user for this instruction. void AddUser(HloInstruction* user); @@ -1600,6 +1616,10 @@ class HloInstruction { InstructionVector operands_; // The set of control predecessors of this instruction. + // Note that the order of the instructions in the vector influences the order + // computed in HloComputation::ComputeInstructionPostOrder, which may + // influence the result of the compilation by changing the scheduling. We are + // not sure if it matters. std::vector control_predecessors_; // The users of this instruction. Users are HLOs where this instruction is an @@ -1618,18 +1638,11 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Describes the dimension numbers used for a dot. - std::unique_ptr dot_dimension_numbers_; - - // Used to tag kCopy instructions that are eligible for copy elision. - bool copy_elision_allowed_ = true; - // The sharding, if one exists. - std::unique_ptr sharding_; - - // Fields used by the kDomain instruction. - std::unique_ptr operand_side_metadata_; - std::unique_ptr user_side_metadata_; + // Uses std::shared_ptr to allow reuse of the same sharding object between + // HloInstructions and other components as HloSharding can be very large for + // many element tuples. + std::shared_ptr sharding_; // Computations called by this instruction. std::vector called_computations_; @@ -1666,10 +1679,12 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string PrecisionToString(const PrecisionConfig::Precision& precision); string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); StatusOr StringToRandomDistribution(const string& name); +StatusOr StringToPrecision(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); @@ -1678,21 +1693,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of // the hlo. Exception: null pointer values compare less than non-null. -// -// Note that this cannot be used for HLO instructions across multiple modules -// since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, - const HloInstruction* const& rhs) const { - if (rhs == nullptr) { - // Nothing compares less than nullptr. - return false; - } - if (lhs == nullptr) { - return true; - } - return lhs->unique_id() < rhs->unique_id(); - } + const HloInstruction* const& rhs) const; }; template diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8a694dde8066ab9a1138b9f7981153d451ddb89e..c1b7c3832b44b5d65b715dffa5211a5c92e17953 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -39,10 +39,8 @@ namespace { using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; -class HloInstructionTest : public HloTestBase { +class HloInstructionTest : public HloVerifiedTestBase { protected: - HloInstructionTest() {} - Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); }; @@ -53,7 +51,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { public: Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("not implemented %s", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } Status HandleParameter(HloInstruction* parameter) override { @@ -1086,16 +1084,14 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { // Fused expression: - // - // x y - // \ / \ - // min broadcast + // y + // / + // x broadcast + // \ / | + // min | // \ / // sub // - // The fusion instruction is elementwise on `x` because the only path from x - // to sub contains only elementwise operations. It is not elementwise on `y` - // because the path y->broadcast->sub is not all elementwise. const Shape r0f32 = ShapeUtil::MakeShape(F32, {}); const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); @@ -1104,10 +1100,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y")); - HloInstruction* min = builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, y)); HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {0})); + builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {})); + HloInstruction* min = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, broadcast)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); @@ -1118,10 +1114,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { EXPECT_FALSE(fusion->IsElementwise()); for (int64 operand_idx = 0; operand_idx < fusion->operand_count(); ++operand_idx) { - if (fusion->operand(operand_idx) == x) { - EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); - } else { + if (fusion->operand(operand_idx) == y) { EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx)); + } else { + EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); } } } @@ -1151,8 +1147,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1192,8 +1188,8 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1243,12 +1239,12 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2))); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); + HloInstruction::CreateBroadcast(data_shape, one, {})); auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape, HloOpcode::kAdd, dot, add_operand)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -1324,8 +1320,8 @@ TEST_F(HloInstructionTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().set_print_metadata(false); @@ -1355,7 +1351,7 @@ TEST_F(HloInstructionTest, Stringification) { TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); @@ -1363,19 +1359,18 @@ TEST_F(HloInstructionTest, StringifyGather_0) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1383,15 +1378,15 @@ TEST_F(HloInstructionTest, StringifyGather_0) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=4, window_bounds={30,29,28,27,26}"); + "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=4, slice_sizes={30,29,28,27,26}"); } TEST_F(HloInstructionTest, StringifyGather_1) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); @@ -1399,19 +1394,18 @@ TEST_F(HloInstructionTest, StringifyGather_1) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1419,10 +1413,10 @@ TEST_F(HloInstructionTest, StringifyGather_1) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=2, window_bounds={30,29,28,27,26}"); + "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=2, slice_sizes={30,29,28,27,26}"); } TEST_F(HloInstructionTest, StringifyScatter) { @@ -1491,8 +1485,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().Canonical(); @@ -1533,8 +1527,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1589,8 +1583,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1745,5 +1739,23 @@ TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { << clone->convolution_dimension_numbers().DebugString(); } +TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) { + constexpr char kHloString[] = R"( + HloModule test_module + ENTRY test { + arg0 = f32[1,2,1] parameter(0) + arg1 = f32[1,1,1] parameter(1) + ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1}, + dim_labels=b0f_0io->b0f, operand_precision={high,default} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHloString)); + auto* conv = module->entry_computation()->root_instruction(); + + auto clone = conv->Clone(); + EXPECT_THAT( + clone->precision_config().operand_precision(), + ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1d71a74c4092291cfd29b9026e50676e1661aad1..e92882c22a6ef1dd43440d3c94c7d233c9a4fb5d 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -17,6 +17,12 @@ limitations under the License. #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -27,10 +33,10 @@ limitations under the License. namespace xla { namespace { -using ::tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { @@ -41,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, return instruction->IsElementwiseOnOperand(operand_index); }); } + +string PrecisionConfigToString(const PrecisionConfig& precision_config) { + if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) { + return static_cast(precision) == + PrecisionConfig::DEFAULT; + })) { + return ""; + } + + return StrCat( + "operand_precision={", + StrJoin( + precision_config.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; + StrAppend(out, + PrecisionToString( + static_cast(precision))); + }), + "}"); +} } // namespace HloBatchNormInstruction::HloBatchNormInstruction( @@ -85,11 +112,10 @@ HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( std::unique_ptr HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), feature_index()); } @@ -107,11 +133,10 @@ HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction( std::unique_ptr HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } @@ -129,18 +154,17 @@ HloBatchNormGradInstruction::HloBatchNormGradInstruction( std::unique_ptr HloBatchNormGradInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } -HloFftInstruction::HloFftInstruction( - const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) +HloFftInstruction::HloFftInstruction(const Shape& shape, + HloInstruction* operand, FftType fft_type, + absl::Span fft_length) : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { fft_length_.assign(fft_length.begin(), fft_length.end()); AppendOperand(operand); @@ -158,7 +182,7 @@ HloInstructionProto HloFftInstruction::ToProto() const { std::vector HloFftInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {StrCat("fft_type=", FftType_Name(fft_type())), - StrCat("fft_length={", Join(fft_length(), ","), "}")}; + StrCat("fft_length={", StrJoin(fft_length(), ","), "}")}; } bool HloFftInstruction::IdenticalSlowPath( @@ -171,12 +195,11 @@ bool HloFftInstruction::IdenticalSlowPath( } std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], fft_type_, - fft_length_); + return absl::make_unique(shape, new_operands[0], fft_type_, + fft_length_); } HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, @@ -226,12 +249,11 @@ HloSendInstruction::HloSendInstruction(HloInstruction* operand, } std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(new_operands[0], new_operands[1], - channel_id(), is_host_transfer()); + return absl::make_unique( + new_operands[0], new_operands[1], channel_id(), is_host_transfer()); } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, @@ -244,11 +266,10 @@ HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, std::unique_ptr HloSendDoneInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( Cast(new_operands[0]), is_host_transfer()); } @@ -265,11 +286,10 @@ HloRecvInstruction::HloRecvInstruction(const Shape& shape, } std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), is_host_transfer()); } @@ -287,35 +307,69 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, std::unique_ptr HloRecvDoneInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( Cast(new_operands[0]), is_host_transfer()); } -HloAllReduceInstruction::HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id) - : HloInstruction(HloOpcode::kCrossReplicaSum, shape), - replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()), - all_reduce_id_(all_reduce_id) { +HloCollectiveInstruction::HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + const std::vector& replica_groups) + : HloInstruction(opcode, shape), replica_groups_(replica_groups) { for (auto operand : operands) { AppendOperand(operand); } +} + +HloInstructionProto HloCollectiveInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; + return proto; +} + +std::vector HloCollectiveInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + std::vector replica_group_str; + for (const ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", StrJoin(group.replica_ids(), ","), "}")); + } + result.push_back( + StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}")); + return result; +} + +bool HloCollectiveInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return absl::c_equal(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return absl::c_equal(a.replica_ids(), b.replica_ids()); + }); +} + +HloAllReduceInstruction::HloAllReduceInstruction( + const Shape& shape, absl::Span operands, + HloComputation* reduce_computation, + const std::vector& replica_groups, absl::string_view barrier, + const absl::optional& all_reduce_id) + : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, + replica_groups), + cross_replica_sum_barrier_(barrier), + all_reduce_id_(all_reduce_id) { AppendComputation(reduce_computation); } HloInstructionProto HloAllReduceInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - for (int64 i : replica_group_ids_) { - proto.add_replica_group_ids(i); - } + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); // Proto3 is so sad. if (all_reduce_id_) { proto.set_all_reduce_id(*all_reduce_id_); @@ -325,9 +379,9 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { } std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& /*options*/) const { - std::vector result = { - StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")}; + const HloPrintOptions& options) const { + std::vector result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } @@ -342,7 +396,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return replica_group_ids() == casted_other.replica_group_ids() && + return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && eq_computations(to_apply(), casted_other.to_apply()) && cross_replica_sum_barrier() == casted_other.cross_replica_sum_barrier() && @@ -351,17 +405,80 @@ bool HloAllReduceInstruction::IdenticalSlowPath( std::unique_ptr HloAllReduceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { - return MakeUnique( - shape, new_operands, to_apply(), replica_group_ids(), + return absl::make_unique( + shape, new_operands, to_apply(), replica_groups(), cross_replica_sum_barrier(), all_reduce_id()); } -HloReverseInstruction::HloReverseInstruction( +HloAllToAllInstruction::HloAllToAllInstruction( + const Shape& shape, absl::Span operands, + const std::vector& replica_groups) + : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, + replica_groups) {} + +std::unique_ptr +HloAllToAllInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique(shape, new_operands, + replica_groups()); +} + +HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) + const std::vector>& source_target_pairs) + : HloInstruction(HloOpcode::kCollectivePermute, shape), + source_target_pairs_(source_target_pairs) { + AppendOperand(operand); +} + +HloInstructionProto HloCollectivePermuteInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (const auto& pair : source_target_pairs()) { + auto* proto_pair = proto.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); + } + return proto; +} + +std::vector +HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + std::vector strs; + for (const auto& pair : source_target_pairs()) { + strs.push_back(StrCat("{", pair.first, ",", pair.second, "}")); + } + result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}")); + return result; +} + +bool HloCollectivePermuteInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return absl::c_equal(source_target_pairs(), + casted_other.source_target_pairs(), + [](const std::pair& a, + const std::pair& b) { return a == b; }); +} + +std::unique_ptr +HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique( + shape, new_operands[0], source_target_pairs()); +} + +HloReverseInstruction::HloReverseInstruction(const Shape& shape, + HloInstruction* operand, + absl::Span dimensions) : HloInstruction(HloOpcode::kReverse, shape), dimensions_(dimensions.begin(), dimensions.end()) { AppendOperand(operand); @@ -377,7 +494,7 @@ HloInstructionProto HloReverseInstruction::ToProto() const { std::vector HloReverseInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReverseInstruction::IdenticalSlowPath( @@ -389,16 +506,15 @@ bool HloReverseInstruction::IdenticalSlowPath( } std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } HloConcatenateInstruction::HloConcatenateInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension) : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) { for (auto operand : operands) { @@ -416,7 +532,7 @@ HloInstructionProto HloConcatenateInstruction::ToProto() const { std::vector HloConcatenateInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloConcatenateInstruction::IdenticalSlowPath( @@ -430,16 +546,15 @@ bool HloConcatenateInstruction::IdenticalSlowPath( std::unique_ptr HloConcatenateInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(shape, new_operands, - dimensions(0)); + return absl::make_unique(shape, new_operands, + dimensions(0)); } HloReduceInstruction::HloReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice args, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span args, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduce, shape), dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { @@ -459,7 +574,7 @@ HloInstructionProto HloReduceInstruction::ToProto() const { std::vector HloReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReduceInstruction::IdenticalSlowPath( @@ -474,12 +589,11 @@ bool HloReduceInstruction::IdenticalSlowPath( } std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands, dimensions(), - to_apply()); + CHECK_EQ(new_operands.size() % 2, 0); + return absl::make_unique(shape, new_operands, + dimensions(), to_apply()); } HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, @@ -502,7 +616,7 @@ HloInstructionProto HloSortInstruction::ToProto() const { std::vector HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloSortInstruction::IdenticalSlowPath( @@ -514,17 +628,17 @@ bool HloSortInstruction::IdenticalSlowPath( } std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; - return MakeUnique(shape, dimensions(0), keys, values); + return absl::make_unique(shape, dimensions(0), keys, + values); } HloTransposeInstruction::HloTransposeInstruction( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) + absl::Span dimensions) : HloInstruction(HloOpcode::kTranspose, shape), dimensions_(dimensions.begin(), dimensions.end()) { CHECK_EQ(shape.dimensions().size(), dimensions.size()); @@ -534,7 +648,7 @@ HloTransposeInstruction::HloTransposeInstruction( Permute(dimensions, shape.dimensions()).begin())) << "shape: " << ShapeUtil::HumanString(shape) << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; + << ", dimensions: {" << StrJoin(dimensions, ", ") << "}"; AppendOperand(operand); } @@ -555,7 +669,7 @@ HloInstructionProto HloTransposeInstruction::ToProto() const { std::vector HloTransposeInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloTransposeInstruction::IdenticalSlowPath( @@ -568,17 +682,16 @@ bool HloTransposeInstruction::IdenticalSlowPath( std::unique_ptr HloTransposeInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } HloBroadcastInstruction::HloBroadcastInstruction( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimension) + absl::Span broadcast_dimension) : HloInstruction(HloOpcode::kBroadcast, shape), dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { AppendOperand(operand); @@ -594,7 +707,7 @@ HloInstructionProto HloBroadcastInstruction::ToProto() const { std::vector HloBroadcastInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloBroadcastInstruction::IdenticalSlowPath( @@ -607,17 +720,16 @@ bool HloBroadcastInstruction::IdenticalSlowPath( std::unique_ptr HloBroadcastInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } -HloMapInstruction::HloMapInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation) +HloMapInstruction::HloMapInstruction(const Shape& shape, + absl::Span operands, + HloComputation* map_computation) : HloInstruction(HloOpcode::kMap, shape) { for (auto operand : operands) { AppendOperand(operand); @@ -638,7 +750,7 @@ HloInstructionProto HloMapInstruction::ToProto() const { } bool HloMapInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { if (!dimensions().empty()) { // Check that the map is executed in elementwise compatible dimensions. if (dimensions().size() != shape().dimensions_size()) { @@ -655,7 +767,7 @@ bool HloMapInstruction::IsElementwiseImpl( std::vector HloMapInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloMapInstruction::IdenticalSlowPath( @@ -666,17 +778,16 @@ bool HloMapInstruction::IdenticalSlowPath( } std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(shape, new_operands, to_apply()); + return absl::make_unique(shape, new_operands, to_apply()); } -HloSliceInstruction::HloSliceInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) +HloSliceInstruction::HloSliceInstruction(const Shape& shape, + HloInstruction* operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) : HloInstruction(HloOpcode::kSlice, shape), slice_starts_(start_indices.begin(), start_indices.end()), slice_limits_(limit_indices.begin(), limit_indices.end()), @@ -713,7 +824,7 @@ std::vector HloSliceInstruction::ExtraAttributesToStringImpl( bounds.push_back( StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); } - return {StrCat("slice={", Join(bounds, ", "), "}")}; + return {StrCat("slice={", StrJoin(bounds, ", "), "}")}; } bool HloSliceInstruction::IdenticalSlowPath( @@ -727,16 +838,15 @@ bool HloSliceInstruction::IdenticalSlowPath( } std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], slice_starts_, - slice_limits_, slice_strides_); + return absl::make_unique( + shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } -HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) - : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), +HloConstantInstruction::HloConstantInstruction(Literal literal) + : HloInstruction(HloOpcode::kConstant, literal.shape()), literal_(std::move(literal)) {} HloConstantInstruction::HloConstantInstruction(const Shape& shape) @@ -744,14 +854,14 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape) HloInstructionProto HloConstantInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - if (literal_ != nullptr) { + if (literal_.has_value()) { *proto.mutable_literal() = literal_->ToProto(); } return proto; } bool HloConstantInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { return true; } @@ -766,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, if (!mutable_array_subshape->has_layout() || !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); + *literal_ = literal_->Relayout(new_layout, shape_index); *mutable_array_subshape->mutable_layout() = new_layout; } } @@ -781,10 +891,10 @@ bool HloConstantInstruction::IdenticalSlowPath( std::unique_ptr HloConstantInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(literal_->CloneToUnique()); + CHECK(literal_.has_value()); + return absl::make_unique(literal_->Clone()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -792,14 +902,14 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( CanonicalNameMap* canonical_name_map) const { string operands; // For constants, show the actual value in place of an empty operand list. - if (literal_ != nullptr && + if (literal_.has_value() && ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector v = tensorflow::str_util::Split(tmp, ' '); + std::vector v = absl::StrSplit(tmp, ' '); bool first = true; // Concatenate elements in "v" with spaces separating them, but ignoring // empty entries. @@ -827,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag, HloInstructionProto HloTraceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_literal() = literal_->ToProto(); + *proto.mutable_literal() = literal_.ToProto(); return proto; } @@ -839,8 +949,7 @@ bool HloTraceInstruction::IdenticalSlowPath( } std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); } @@ -858,7 +967,7 @@ HloFusionInstruction::HloFusionInstruction(const Shape& shape, HloFusionInstruction::HloFusionInstruction( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation) : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { for (auto operand : operands) { @@ -891,7 +1000,7 @@ HloInstructionProto HloFusionInstruction::ToProto() const { } bool HloFusionInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { if (!operand_idx.has_value()) { for (auto* fused : fused_instructions()) { if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { @@ -1094,7 +1203,7 @@ HloInstruction* HloFusionInstruction::FuseInstructionInternal( HloInstruction* HloFusionInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse, bool add_output) { - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString(); VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); HloInstruction* clone = nullptr; if (called_computations().empty()) { @@ -1265,8 +1374,7 @@ bool HloFusionInstruction::IdenticalSlowPath( } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloModule* module = context != nullptr ? context->module() : GetModule(); HloComputation* new_fused_computation = nullptr; @@ -1278,8 +1386,8 @@ std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( new_fused_computation = module->AddEmbeddedComputation( fused_instructions_computation()->Clone("clone", context)); } - return MakeUnique(shape, fusion_kind(), new_operands, - new_fused_computation); + return absl::make_unique( + shape, fusion_kind(), new_operands, new_fused_computation); } Status HloFusionInstruction::DeduplicateFusionOperands() { @@ -1304,7 +1412,7 @@ Status HloFusionInstruction::DeduplicateFusionOperands() { HloRngInstruction::HloRngInstruction( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters) + absl::Span parameters) : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) { for (HloInstruction* param : parameters) { AppendOperand(param); @@ -1323,7 +1431,7 @@ std::vector HloRngInstruction::ExtraAttributesToStringImpl( } bool HloRngInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { return true; } @@ -1335,10 +1443,10 @@ bool HloRngInstruction::IdenticalSlowPath( } std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(shape, distribution_, new_operands); + return absl::make_unique(shape, distribution_, + new_operands); } HloParameterInstruction::HloParameterInstruction(int64 parameter_number, @@ -1371,10 +1479,10 @@ bool HloParameterInstruction::IdenticalSlowPath( std::unique_ptr HloParameterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(parameter_number_, shape, name()); + return absl::make_unique(parameter_number_, shape, + name()); } HloGetTupleElementInstruction::HloGetTupleElementInstruction( @@ -1406,12 +1514,11 @@ bool HloGetTupleElementInstruction::IdenticalSlowPath( std::unique_ptr HloGetTupleElementInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - tuple_index()); + return absl::make_unique( + shape, new_operands[0], tuple_index()); } HloReducePrecisionInstruction::HloReducePrecisionInstruction( @@ -1449,11 +1556,10 @@ bool HloReducePrecisionInstruction::IdenticalSlowPath( std::unique_ptr HloReducePrecisionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], exponent_bits(), mantissa_bits()); } @@ -1467,13 +1573,6 @@ HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, AppendOperand(token_operand); } -HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, - const string& config) - : HloInstruction(HloOpcode::kInfeed, - ShapeUtil::MakeTupleShape( - {infeed_shape, ShapeUtil::MakeTokenShape()})), - infeed_config_(config) {} - HloInstructionProto HloInfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_infeed_config(infeed_config_); @@ -1497,24 +1596,20 @@ bool HloInfeedInstruction::IdenticalSlowPath( } std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - if (new_operands.empty()) { - return MakeUnique(infeed_shape(), infeed_config()); - } else { - CHECK_EQ(new_operands.size(), 1); - return MakeUnique(infeed_shape(), new_operands[0], - infeed_config()); - } + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique( + infeed_shape(), new_operands[0], infeed_config()); } -HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) +HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + HloInstruction* token_operand, + absl::string_view outfeed_config) : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), outfeed_shape_(outfeed_shape), - outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + outfeed_config_(outfeed_config) { CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) << "Outfeed shape " << outfeed_shape << " must be compatible with operand shape " << operand->shape(); @@ -1522,18 +1617,6 @@ HloOutfeedInstruction::HloOutfeedInstruction( AppendOperand(token_operand); } -HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& outfeed_shape, HloInstruction* operand, - tensorflow::StringPiece outfeed_config) - : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), - outfeed_shape_(outfeed_shape), - outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { - CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) - << "Outfeed shape " << outfeed_shape - << " must be compatible with operand shape " << operand->shape(); - AppendOperand(operand); -} - HloInstructionProto HloOutfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_outfeed_config(outfeed_config()); @@ -1558,25 +1641,23 @@ bool HloOutfeedInstruction::IdenticalSlowPath( } std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - if (new_operands.size() == 1) { - return MakeUnique(outfeed_shape(), new_operands[0], - outfeed_config()); - } else { - CHECK_EQ(new_operands.size(), 2); - return MakeUnique(outfeed_shape(), new_operands[0], - new_operands[1], outfeed_config()); - } + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique( + outfeed_shape(), new_operands[0], new_operands[1], outfeed_config()); } HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), + feature_group_count_(feature_group_count), window_(window), - convolution_dimension_numbers_(dimension_numbers) { + convolution_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1603,6 +1684,8 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_window() = window_; *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; + proto.set_feature_group_count(feature_group_count_); + *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1614,6 +1697,15 @@ std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( } extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( convolution_dimension_numbers_))); + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; } @@ -1623,21 +1715,25 @@ bool HloConvolutionInstruction::IdenticalSlowPath( eq_computations) const { const auto& casted_other = static_cast(other); + if (feature_group_count_ != other.feature_group_count()) { + return false; + } return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), - casted_other.convolution_dimension_numbers()); + casted_other.convolution_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); } std::unique_ptr HloConvolutionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands[0], - new_operands[1], window(), - convolution_dimension_numbers_); + return absl::make_unique( + shape, new_operands[0], new_operands[1], feature_group_count_, window(), + convolution_dimension_numbers_, precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -1676,11 +1772,10 @@ bool HloReduceWindowInstruction::IdenticalSlowPath( std::unique_ptr HloReduceWindowInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], window(), to_apply()); } @@ -1725,21 +1820,20 @@ bool HloSelectAndScatterInstruction::IdenticalSlowPath( std::unique_ptr HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], select(), window(), new_operands[1], new_operands[2], scatter()); } HloCustomCallInstruction::HloCustomCallInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target) : HloInstruction(HloOpcode::kCustomCall, shape), - custom_call_target_(custom_call_target.begin(), - custom_call_target.end()) { + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + feature_group_count_(1) { for (auto operand : operands) { AppendOperand(operand); } @@ -1755,6 +1849,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { *convolution_dimension_numbers_; } proto.set_custom_call_target(custom_call_target_); + proto.set_feature_group_count(feature_group_count_); return proto; } @@ -1769,6 +1864,9 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( "dim_labels=", ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. @@ -1796,60 +1894,28 @@ bool HloCustomCallInstruction::IdenticalSlowPath( casted_other.convolution_dimension_numbers()))) { return false; } + if (feature_group_count_ != casted_other.feature_group_count_) { + return false; + } return custom_call_target_ == casted_other.custom_call_target_; } std::unique_ptr HloCustomCallInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - auto cloned = MakeUnique(shape, new_operands, - custom_call_target()); + auto cloned = absl::make_unique( + shape, new_operands, custom_call_target()); if (window_ != nullptr) { cloned->set_window(*window_); } if (convolution_dimension_numbers_ != nullptr) { cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); } + cloned->set_feature_group_count(feature_group_count_); return std::move(cloned); } -HloHostComputeInstruction::HloHostComputeInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) - : HloInstruction(HloOpcode::kHostCompute, shape), - channel_name_(channel_name.begin(), channel_name.end()), - cost_estimate_ns_(cost_estimate_ns) { - for (auto operand : operands) { - AppendOperand(operand); - } -} - -HloInstructionProto HloHostComputeInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - proto.set_channel_name(channel_name_); - proto.set_cost_estimate_ns(cost_estimate_ns_); - return proto; -} - -bool HloHostComputeInstruction::IdenticalSlowPath( - const HloInstruction& other, - const std::function& - eq_computations) const { - // Not yet supported. - return false; -} - -std::unique_ptr -HloHostComputeInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, - HloCloneContext* context) const { - return MakeUnique( - shape, new_operands, channel_name_, cost_estimate_ns_); -} - HloPadInstruction::HloPadInstruction(const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, @@ -1880,17 +1946,16 @@ bool HloPadInstruction::IdenticalSlowPath( } std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands[0], new_operands[1], - padding_config_); + return absl::make_unique(shape, new_operands[0], + new_operands[1], padding_config_); } HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes) + absl::Span slice_sizes) : HloInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); @@ -1907,8 +1972,8 @@ HloInstructionProto HloDynamicSliceInstruction::ToProto() const { std::vector HloDynamicSliceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return { - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")}; + return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","), + "}")}; } bool HloDynamicSliceInstruction::IdenticalSlowPath( @@ -1920,60 +1985,57 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath( std::unique_ptr HloDynamicSliceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); } HloGatherInstruction::HloGatherInstruction( - const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) + absl::Span slice_sizes) : HloInstruction(HloOpcode::kGather, shape) { AppendOperand(operand); - AppendOperand(gather_indices); + AppendOperand(start_indices); gather_dimension_numbers_ = - MakeUnique(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); + absl::make_unique(gather_dim_numbers); + absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); } string HloGatherInstruction::GatherDimensionNumbersToString() const { CHECK(gather_dimension_numbers_ != nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string offset_dims = + StrCat("offset_dims={", + StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}"); + string collapsed_slice_dims = StrCat( + "collapsed_slice_dims={", + StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); + string start_index_map = + StrCat("start_index_map={", + StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - return Join>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, + return StrJoin>( + {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, ", "); } /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, - int64 index_vector_dim) { + absl::Span offset_dims, + absl::Span collapsed_slice_dims, + absl::Span start_index_map, int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); + for (int64 output_window_dim : offset_dims) { + gather_dim_numbers.add_offset_dims(output_window_dim); } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); + for (int64 elided_window_dim : collapsed_slice_dims) { + gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim); } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + for (int64 gather_dim_to_input_dim : start_index_map) { + gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim); } gather_dim_numbers.set_index_vector_dim(index_vector_dim); @@ -1983,8 +2045,8 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const { HloInstructionProto HloGatherInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers(); - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); + for (int64 bound : gather_slice_sizes()) { + proto.add_gather_slice_sizes(bound); } return proto; } @@ -1992,7 +2054,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const { std::vector HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {GatherDimensionNumbersToString(), - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; + StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")}; } bool HloGatherInstruction::IdenticalSlowPath( @@ -2003,17 +2065,16 @@ bool HloGatherInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals( gather_dimension_numbers(), casted_other.gather_dimension_numbers()) && - gather_window_bounds() == casted_other.gather_window_bounds(); + gather_slice_sizes() == casted_other.gather_slice_sizes(); } std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], gather_dimension_numbers(), - gather_window_bounds()); + gather_slice_sizes()); } HloScatterInstruction::HloScatterInstruction( @@ -2027,24 +2088,24 @@ HloScatterInstruction::HloScatterInstruction( AppendOperand(updates); AppendComputation(update_computation); scatter_dimension_numbers_ = - MakeUnique(scatter_dim_numbers); + absl::make_unique(scatter_dim_numbers); } string HloScatterInstruction::ScatterDimensionNumbersToString() const { - string update_window_dims = - StrCat("update_window_dims={", - Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string update_window_dims = StrCat( + "update_window_dims={", + StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}"); string inserted_window_dims = StrCat( "inserted_window_dims={", - Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); string scatter_dims_to_operand_dims = StrCat( "scatter_dims_to_operand_dims={", - Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); - return Join>( + return StrJoin>( {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim}, ", "); @@ -2052,9 +2113,9 @@ string HloScatterInstruction::ScatterDimensionNumbersToString() const { /* static */ ScatterDimensionNumbers HloScatterInstruction::MakeScatterDimNumbers( - tensorflow::gtl::ArraySlice update_window_dims, - tensorflow::gtl::ArraySlice inserted_window_dims, - tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + absl::Span update_window_dims, + absl::Span inserted_window_dims, + absl::Span scatter_dims_to_operand_dims, int64 index_vector_dim) { ScatterDimensionNumbers scatter_dim_numbers; for (int64 update_window_dim : update_window_dims) { @@ -2094,13 +2155,150 @@ bool HloScatterInstruction::IdenticalSlowPath( } std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), scatter_dimension_numbers()); } +HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) + : HloInstruction(HloOpcode::kIota, shape), + iota_dimension_(iota_dimension) {} + +HloInstructionProto HloIotaInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(iota_dimension()); + return proto; +} + +std::vector HloIotaInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("iota_dimension=", iota_dimension())}; +} + +bool HloIotaInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return iota_dimension() == casted_other.iota_dimension(); +} + +std::unique_ptr HloIotaInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + return absl::make_unique(shape, iota_dimension()); +} + +HloDotInstruction::HloDotInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) + : HloInstruction(HloOpcode::kDot, shape), + dot_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { + AppendOperand(lhs); + AppendOperand(rhs); +} + +HloInstructionProto HloDotInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_; + *proto.mutable_precision_config() = precision_config_; + return proto; +} + +std::vector HloDotInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra = {DotDimensionNumbersToString()}; + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; +} + +bool HloDotInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + casted_other.dot_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); +} + +std::unique_ptr HloDotInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique( + shape, new_operands[0], new_operands[1], dot_dimension_numbers_, + precision_config_); +} + +string HloDotInstruction::DotDimensionNumbersToString() const { + std::vector result; + const DotDimensionNumbers& dnums = dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result.push_back(StrCat("lhs_batch_dims={", + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("lhs_contracting_dims={", + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); + + if (!dnums.rhs_batch_dimensions().empty()) { + result.push_back(StrCat("rhs_batch_dims={", + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("rhs_contracting_dims={", + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); + + return StrJoin(result, ", "); +} + +HloDomainInstruction::HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata) + : HloInstruction(HloOpcode::kDomain, shape), + operand_side_metadata_(std::move(operand_side_metadata)), + user_side_metadata_(std::move(user_side_metadata)) { + AppendOperand(operand); +} + +std::vector HloDomainInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")}; + } + return {}; +} + +bool HloDomainInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return operand_side_metadata().Matches( + casted_other.operand_side_metadata()) && + user_side_metadata().Matches(casted_other.user_side_metadata()); +} + +std::unique_ptr HloDomainInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique( + shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index b0388223376a65dbe86ee273246d2ace229ada13..2d7bc83855e761ed313d831a1252a54130910bbe 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -66,8 +67,7 @@ class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -81,8 +81,7 @@ class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -96,8 +95,7 @@ class HloBatchNormGradInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -105,7 +103,7 @@ class HloFftInstruction : public HloInstruction { public: explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); FftType fft_type() const { return fft_type_; } const std::vector& fft_length() const { return fft_length_; } @@ -123,8 +121,7 @@ class HloFftInstruction : public HloInstruction { // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes FFT type for an FFT instruction. @@ -173,8 +170,7 @@ class HloSendInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -186,8 +182,7 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -199,8 +194,7 @@ class HloRecvInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -212,24 +206,41 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; -class HloAllReduceInstruction : public HloInstruction { +class HloCollectiveInstruction : public HloInstruction { + public: + const std::vector& replica_groups() const { + return replica_groups_; + } + + protected: + explicit HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + const std::vector& replica_groups); + + HloInstructionProto ToProto() const override; + + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + std::vector replica_groups_; +}; + +class HloAllReduceInstruction : public HloCollectiveInstruction { public: explicit HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id); - - // Returns the group ids of each replica for CrossReplicaSum op. - const std::vector& replica_group_ids() const { - return replica_group_ids_; - } + const std::vector& replica_groups, + absl::string_view barrier, const absl::optional& all_reduce_id); // Returns the barrier config used for the CrossReplicaSum implementation of // each backend. @@ -240,9 +251,7 @@ class HloAllReduceInstruction : public HloInstruction { cross_replica_sum_barrier_ = barrier; } - tensorflow::gtl::optional all_reduce_id() const { - return all_reduce_id_; - } + absl::optional all_reduce_id() const { return all_reduce_id_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -257,26 +266,64 @@ class HloAllReduceInstruction : public HloInstruction { // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // The group id of each replica for CrossReplicaSum. - std::vector replica_group_ids_; - // The string representation of the barrier config used for CrossReplicaSum. string cross_replica_sum_barrier_; // For Allreduce nodes from different modules, if they have the same // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be // applied cross modules. - tensorflow::gtl::optional all_reduce_id_; + absl::optional all_reduce_id_; +}; + +class HloAllToAllInstruction : public HloCollectiveInstruction { + public: + explicit HloAllToAllInstruction( + const Shape& shape, absl::Span operands, + const std::vector& replica_groups); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; +}; + +class HloCollectivePermuteInstruction : public HloInstruction { + public: + explicit HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs); + + const std::vector>& source_target_pairs() const { + return source_target_pairs_; + } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + const std::vector> source_target_pairs_; }; class HloReverseInstruction : public HloInstruction { public: explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -292,8 +339,7 @@ class HloReverseInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -301,9 +347,9 @@ class HloReverseInstruction : public HloInstruction { class HloConcatenateInstruction : public HloInstruction { public: - explicit HloConcatenateInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - int64 dimension); + explicit HloConcatenateInstruction(const Shape& shape, + absl::Span operands, + int64 dimension); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -321,8 +367,7 @@ class HloConcatenateInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -330,16 +375,30 @@ class HloConcatenateInstruction : public HloInstruction { class HloReduceInstruction : public HloInstruction { public: - explicit HloReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice args, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reduce_computation); + explicit HloReduceInstruction(const Shape& shape, + absl::Span args, + absl::Span dimensions_to_reduce, + HloComputation* reduce_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the number of input arrays (and, consequentially, the number of + // init values) this reduce has. + int64 input_count() const { return operand_count() / 2; } + + // Returns the input tensors to be reduced. + absl::Span inputs() const { + return absl::MakeSpan(operands()).subspan(0, input_count()); + } + + // Returns the init values of the reduction. + absl::Span init_values() const { + return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); + } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -349,8 +408,7 @@ class HloReduceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -378,8 +436,7 @@ class HloSortInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -387,9 +444,8 @@ class HloSortInstruction : public HloInstruction { class HloTransposeInstruction : public HloInstruction { public: - explicit HloTransposeInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand, + absl::Span dimensions); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -407,8 +463,7 @@ class HloTransposeInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -416,9 +471,8 @@ class HloTransposeInstruction : public HloInstruction { class HloBroadcastInstruction : public HloInstruction { public: - explicit HloBroadcastInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimension); + explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand, + absl::Span broadcast_dimension); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -434,8 +488,7 @@ class HloBroadcastInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -443,9 +496,9 @@ class HloBroadcastInstruction : public HloInstruction { class HloMapInstruction : public HloInstruction { public: - explicit HloMapInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation); + explicit HloMapInstruction(const Shape& shape, + absl::Span operands, + HloComputation* map_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -454,7 +507,7 @@ class HloMapInstruction : public HloInstruction { private: bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const override; + const absl::optional& operand_idx) const override; std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; bool IdenticalSlowPath( @@ -463,8 +516,7 @@ class HloMapInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -473,9 +525,9 @@ class HloMapInstruction : public HloInstruction { class HloSliceInstruction : public HloInstruction { public: explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); HloInstructionProto ToProto() const override; @@ -514,8 +566,7 @@ class HloSliceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes the [begin, end) index range for a slice. @@ -529,13 +580,13 @@ class HloSliceInstruction : public HloInstruction { class HloConstantInstruction : public HloInstruction { public: - explicit HloConstantInstruction(std::unique_ptr literal); + explicit HloConstantInstruction(Literal literal); // Used when the literal is too large and dropped. explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } // Returns whether there is literal associated with this instruction. - bool HasLiteral() const { return literal_ != nullptr; } + bool HasLiteral() const { return literal_.has_value(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -547,7 +598,7 @@ class HloConstantInstruction : public HloInstruction { private: bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const override; + const absl::optional& operand_idx) const override; bool IdenticalSlowPath( const HloInstruction& other, const std::function& @@ -557,18 +608,16 @@ class HloConstantInstruction : public HloInstruction { CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr literal_; + absl::optional literal_; }; class HloTraceInstruction : public HloInstruction { public: explicit HloTraceInstruction(const string& tag, HloInstruction* operand); // Returns a tag to be used in tracing. - string TracingTag() const { return literal_->GetR1U8AsString(); } + string TracingTag() const { return literal_.GetR1U8AsString(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -579,11 +628,9 @@ class HloTraceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr literal_; + Literal literal_; }; class HloFusionInstruction : public HloInstruction { @@ -591,10 +638,9 @@ class HloFusionInstruction : public HloInstruction { explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); - explicit HloFusionInstruction( - const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, - HloComputation* fusion_computation); + explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, + absl::Span operands, + HloComputation* fusion_computation); string ToCategory() const override; // Returns a serialized representation of this instruction. @@ -698,7 +744,7 @@ class HloFusionInstruction : public HloInstruction { bool add_output = false); bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const override; + const absl::optional& operand_idx) const override; std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; bool IdenticalSlowPath( @@ -707,8 +753,7 @@ class HloFusionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The type of the fusion. Used by kFusion only. @@ -717,9 +762,9 @@ class HloFusionInstruction : public HloInstruction { class HloRngInstruction : public HloInstruction { public: - explicit HloRngInstruction( - const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters); + explicit HloRngInstruction(const Shape& shape, + RandomDistribution distribution, + absl::Span parameters); // Returns the random distribution for this rng node. RandomDistribution random_distribution() const { return distribution_; } // Returns a serialized representation of this instruction. @@ -727,7 +772,7 @@ class HloRngInstruction : public HloInstruction { private: bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const override; + const absl::optional& operand_idx) const override; std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; bool IdenticalSlowPath( @@ -736,8 +781,7 @@ class HloRngInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The distribution requested for random number generation. @@ -762,8 +806,7 @@ class HloParameterInstruction : public HloInstruction { CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; int64 parameter_number_ = 0; @@ -787,8 +830,7 @@ class HloGetTupleElementInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; int64 tuple_index_ = -1; @@ -816,8 +858,7 @@ class HloReducePrecisionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The bit sizes for a reduce-precision operation. @@ -830,10 +871,6 @@ class HloInfeedInstruction : public HloInstruction { explicit HloInfeedInstruction(const Shape& infeed_shape, HloInstruction* token_operand, const string& config); - // TODO(b/80000000): Remove this constructor when all uses of infeed are - // converted to take tokens. - explicit HloInfeedInstruction(const Shape& infeed_shape, - const string& config); // Returns the infeed configuration string. The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. @@ -858,8 +895,7 @@ class HloInfeedInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The string representation of the infeed configuration. @@ -871,13 +907,7 @@ class HloOutfeedInstruction : public HloInstruction { explicit HloOutfeedInstruction(const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, - tensorflow::StringPiece outfeed_config); - // TODO(b/80000000): Remove this constructor when all uses of outfeed are - // converted to take tokens. - explicit HloOutfeedInstruction(const Shape& outfeed_shape, - HloInstruction* operand, - tensorflow::StringPiece outfeed_config); - + absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); @@ -897,8 +927,7 @@ class HloOutfeedInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Shape of outfeed request. @@ -911,8 +940,9 @@ class HloConvolutionInstruction : public HloInstruction { public: explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -922,6 +952,19 @@ class HloConvolutionInstruction : public HloInstruction { const ConvolutionDimensionNumbers& dnums) { convolution_dimension_numbers_ = dnums; } + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count() const { return feature_group_count_; } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -935,12 +978,18 @@ class HloConvolutionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count_; + // Describes the window used for a convolution. Window window_; // Describes the dimension numbers used for a convolution. ConvolutionDimensionNumbers convolution_dimension_numbers_; + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; }; class HloReduceWindowInstruction : public HloInstruction { @@ -964,8 +1013,7 @@ class HloReduceWindowInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; }; @@ -1013,24 +1061,23 @@ class HloSelectAndScatterInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; }; class HloCustomCallInstruction : public HloInstruction { public: - explicit HloCustomCallInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target); + explicit HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target); const Window& window() const override { CHECK(window_ != nullptr); return *window_; } void set_window(const Window& window) override { - window_ = MakeUnique(window); + window_ = absl::make_unique(window); } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -1041,9 +1088,13 @@ class HloCustomCallInstruction : public HloInstruction { void set_convolution_dimension_numbers( const ConvolutionDimensionNumbers& dnums) { convolution_dimension_numbers_ = - MakeUnique(dnums); + absl::make_unique(dnums); } const string& custom_call_target() const { return custom_call_target_; } + void set_feature_group_count(int64 feature_group_count) { + feature_group_count_ = feature_group_count; + } + int64 feature_group_count() const { return feature_group_count_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1056,8 +1107,7 @@ class HloCustomCallInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; @@ -1065,33 +1115,8 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr window_; // Describes the dimension numbers used for a convolution. std::unique_ptr convolution_dimension_numbers_; -}; - -class HloHostComputeInstruction : public HloInstruction { - public: - explicit HloHostComputeInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); - // Returns the channel name associated with the instruction. The name is - // used to identify host Send/Recv operations. - const string& channel_name() const { return channel_name_; } - // Returns a serialized representation of this instruction. - HloInstructionProto ToProto() const override; - - private: - bool IdenticalSlowPath( - const HloInstruction& other, - const std::function& - eq_computations) const override; - // Implementation for non-common logic of CloneWithNewOperands. - std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, - HloCloneContext* context) const override; - // Name to use for host send/recv channels. - string channel_name_; - // Estimate of the duration of a host computation in nanoseconds. - int64 cost_estimate_ns_ = 0; + // The number of feature groups. This is used for grouped convolutions. + int64 feature_group_count_; }; class HloPadInstruction : public HloInstruction { @@ -1113,8 +1138,7 @@ class HloPadInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The padding configuration that describes the edge padding and interior @@ -1124,10 +1148,10 @@ class HloPadInstruction : public HloInstruction { class HloDynamicSliceInstruction : public HloInstruction { public: - explicit HloDynamicSliceInstruction( - const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + explicit HloDynamicSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* start_indices, + absl::Span slice_sizes); // Old methods kept for smooth subclassing transition END. // Returns the size of the slice in the given dimension for a dynamic // slice node. @@ -1149,8 +1173,7 @@ class HloDynamicSliceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes the [start, start + size) range size for a dynamic slice @@ -1162,15 +1185,15 @@ class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction( const Shape& shape, HloInstruction* operand, - HloInstruction* gather_indices, + HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); const GatherDimensionNumbers& gather_dimension_numbers() const { CHECK(gather_dimension_numbers_ != nullptr); return *gather_dimension_numbers_; } - tensorflow::gtl::ArraySlice gather_window_bounds() const { - return gather_window_bounds_; + absl::Span gather_slice_sizes() const { + return gather_slice_sizes_; } // Returns the dump string of the gather dimension numbers. string GatherDimensionNumbersToString() const; @@ -1179,10 +1202,9 @@ class HloGatherInstruction : public HloInstruction { // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, - int64 index_vector_dim); + absl::Span offset_dims, + absl::Span collapsed_slice_dims, + absl::Span start_index_map, int64 index_vector_dim); private: std::vector ExtraAttributesToStringImpl( @@ -1192,12 +1214,11 @@ class HloGatherInstruction : public HloInstruction { const std::function& eq_computations) const override; std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::unique_ptr gather_dimension_numbers_; - std::vector gather_window_bounds_; + std::vector gather_slice_sizes_; }; class HloScatterInstruction : public HloInstruction { @@ -1218,9 +1239,9 @@ class HloScatterInstruction : public HloInstruction { // Creates an instance of ScatterDimensionNumbers. static ScatterDimensionNumbers MakeScatterDimNumbers( - tensorflow::gtl::ArraySlice update_window_dims, - tensorflow::gtl::ArraySlice inserted_window_dims, - tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + absl::Span update_window_dims, + absl::Span inserted_window_dims, + absl::Span scatter_dims_to_operand_dims, int64 index_vector_dim); private: @@ -1232,13 +1253,114 @@ class HloScatterInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::unique_ptr scatter_dimension_numbers_; }; +class HloIotaInstruction : public HloInstruction { + public: + explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + int64 iota_dimension() const { return iota_dimension_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + const int64 iota_dimension_; +}; + +class HloDotInstruction : public HloInstruction { + public: + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); + + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + return dot_dimension_numbers_; + } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + + // Describes the dimension numbers used for a dot. + DotDimensionNumbers dot_dimension_numbers_; + + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; +}; + +class HloDomainInstruction : public HloInstruction { + public: + explicit HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata); + + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + std::unique_ptr operand_side_metadata_; + std::unique_ptr user_side_metadata_; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 71b44507cc704344ff6fe5269ea498bb32cfb8a6..d9be841dd751651ba029998fd062fcaec3691945 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,20 +17,20 @@ limitations under the License. #include +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" namespace xla { - -using ::tensorflow::StringPiece; - namespace { +using absl::string_view; + constexpr int kEOF = -1; constexpr int kError = -2; @@ -66,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -tensorflow::StringPiece HloLexer::StringPieceFromPointers( - const char* begin, const char* end) const { +absl::string_view HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return tensorflow::StringPiece(begin, end - begin); + return absl::string_view(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -143,8 +143,47 @@ TokKind HloLexer::LexToken() { return TokKind::kLparen; case ')': return TokKind::kRparen; - case '/': - return LexComment(); + case '/': { + if (PeekCurrentChar() == '*') { + // This is the start of a /*...*/ delimited comment. Save the current + // location in case the comment is unterminated so the error message + // will point to the beginning of the comment. + const char* comment_start = current_ptr_; + current_ptr_++; + // Advance until '*/' is found. + while (true) { + int current = GetNextChar(); + if (current == '*' && PeekCurrentChar() == '/') { + // End of comment. + current_ptr_++; + break; + } + if (current == kEOF) { + // Unterminated comment. + current_ptr_ = comment_start; + return TokKind::kError; + } + } + // Return no token for the comment. Keep lexing. + continue; + } else if (PeekCurrentChar() == '/') { + // This is the start of a '//' delimited comment. Throw away + // everything until end of line or file. The end-of-line character(s) + // are left unlexed in the buffer which is harmless because these are + // skipped later by the lexer. This approach enables support for + // different end-of-line encodings. + while (true) { + int current = PeekCurrentChar(); + if (current == kEOF || current == '\n' || current == '\r') { + break; + } + current_ptr_++; + } + continue; + } + // A lone '/' is an error. + return TokKind::kError; + } case '"': return LexString(); } @@ -196,7 +235,7 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - tensorflow::StringPiece identifier = + absl::string_view identifier = StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. @@ -230,7 +269,7 @@ TokKind HloLexer::LexIdentifier() { } } - str_val_ = std::string(identifier); + str_val_ = string(identifier); return TokKind::kIdent; } @@ -267,8 +306,7 @@ TokKind HloLexer::LexNumberOrPattern() { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), - &decimal_val_); + CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); return TokKind::kDecimal; } @@ -300,7 +338,7 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + if (absl::SimpleAtoi(slice, &int64_val_)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -326,6 +364,7 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no = line_no_cache_.line_no_of_query; } for (; ptr != location; ptr++) { + CHECK_LT(ptr, buf_.end()); if (*ptr == '\n') { line_no++; } @@ -335,38 +374,28 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == tensorflow::StringPiece::npos) { + if (line_offset == absl::string_view::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { +absl::string_view HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == tensorflow::StringPiece::npos + const char* start = line_start == absl::string_view::npos ? buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); const char* end = - line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; + line_end == absl::string_view::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } -TokKind HloLexer::LexComment() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"}; - if (RE2::Consume(&consumable, *comment_pattern)) { - current_ptr_ = consumable.begin(); - return TokKind::kComment; - } - return TokKind::kError; -} - // Lexes quoted string with escaping characters. If matched, the quoted string // will be unescaped and stored to str_val_. TokKind HloLexer::LexString() { @@ -374,10 +403,10 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::StringPiece raw = + absl::string_view raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; - if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { + if (!absl::CUnescape(raw, &str_val_, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; return TokKind::kError; } @@ -412,8 +441,6 @@ string TokKindToString(TokKind kind) { return "kRparen"; case TokKind::kArrow: return "kArrow"; - case TokKind::kComment: - return "kComment"; case TokKind::kw_HloModule: return "kw_HloModule"; case TokKind::kw_ENTRY: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index ceb674f25e94ac3ac2e6a4a0687a93ffdcd065e0..3e2f8bcd52f9043f161197756a2060b28dded1d9 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/types.h" @@ -34,7 +34,7 @@ namespace xla { // it directly. class HloLexer { public: - explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + explicit HloLexer(absl::string_view buf) : buf_(buf) { current_ptr_ = buf_.begin(); } @@ -77,7 +77,7 @@ class HloLexer { std::pair GetLineAndColumn(LocTy location) const; // Returns the whole line given the location. - tensorflow::StringPiece GetLine(LocTy loc) const; + absl::string_view GetLine(LocTy loc) const; private: // Returns the current character. If it's neither the end of input buffer nor @@ -89,8 +89,8 @@ class HloLexer { // Creates StringPiece with the given begin and end. Exits if the begin > end, // or it's out of the range of the current buffer. - tensorflow::StringPiece StringPieceFromPointers(const char* begin, - const char* end) const; + absl::string_view StringPieceFromPointers(const char* begin, + const char* end) const; tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( const char* begin, const char* end) const; @@ -105,14 +105,13 @@ class HloLexer { TokKind LexShape(); TokKind LexConstant(); TokKind LexNumberOrPattern(); - TokKind LexComment(); TokKind LexString(); - const tensorflow::StringPiece buf_; + const absl::string_view buf_; const char* current_ptr_; // Information about the current token. - const char* token_start_; + const char* token_start_ = nullptr; TokKind current_kind_; string str_val_; Shape shape_val_; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 43c41ece6efc4f9e8ca74f16e0f63d29abc4de4e..5bf055f3c012fef687cdc275d62efdf2d4cd5e5c 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -17,8 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -29,17 +30,14 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { using Worklist = std::deque; using Workset = std::unordered_set; -namespace { - void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, Workset* workset) { if (workset->count(instruction) == 0) { @@ -221,6 +219,33 @@ void PropagateLivenessToParameterCallers( } } +// Makes sure that if a live instruction is within a computation used in control +// flow operations, we mark live even other related instructions. +void PropagateLivenessThroughControlFlow( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + HloInstruction* caller = callsite.instruction(); + if (caller->opcode() == HloOpcode::kWhile) { + // If a live instruction is within the %while body or condition + // computation, mark the predicate value returned by the condition + // computation live as well. + MarkLiveAtIndex(caller->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); + } else if (caller->opcode() == HloOpcode::kConditional) { + // If a live instruction is within the true or false branches of a + // conditional, we mark the predicate operand live as well. + MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist, + workset); + } + } + } +} + } // namespace HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) @@ -259,12 +284,10 @@ void HloLivenessAnalysis::RunAnalysis() { } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kWhile && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kWhile) { PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kParameter && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kParameter) { PropagateLivenessToParameterCallers(instruction, &live_index_map_, &worklist, &workset, call_graph_.get()); @@ -279,6 +302,8 @@ void HloLivenessAnalysis::RunAnalysis() { MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); } } + PropagateLivenessThroughControlFlow(instruction, &live_index_map_, + &worklist, &workset, call_graph_.get()); } } @@ -296,7 +321,7 @@ StatusOr> HloLivenessAnalysis::Run( VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module)); liveness_analysis->RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 01b625c29ca2823b2a2490b30a9d4d5128b4c22e..e0ae1173c6114f0bc6ef18b2cfff9d54ccfe2faf 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -398,5 +398,89 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); } +TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + InnerWhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + InnerWhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + OuterWhileCondition { + cond_param.2 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0 + constant.5 = s32[] constant(5) + ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5) + } + OuterWhileBody { + body_param.2 = (s32[]) parameter(0) + get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0 + constant.6 = s32[] constant(0) + tuple.2 = (s32[]) tuple(constant.6) + inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition, + body=InnerWhileBody + constant.7 = s32[] constant(1) + add.2 = s32[] add(get-tuple-element.8, constant.7) + ROOT rtuple = (s32[]) tuple(add.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=OuterWhileCondition, + body=OuterWhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 7e4b8834357d39099f76450b849d6b5624e4e3b4..5269cad94d35be3dd1c009588bbe422ff1533364 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -15,15 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { -using ::tensorflow::str_util::Join; - bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -210,8 +208,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong lhs_contracting_dimensions (got {" - << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" - << lhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",") + << "} want {" << lhs_contracting_dim_ << "})"; return false; } @@ -219,8 +217,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong rhs_contracting_dimensions (got {" - << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" - << rhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",") + << "} want {" << rhs_contracting_dim_ << "})"; return false; } diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index b57c940238f0672692e3b65827f43e2f5499502d..5502e565b6dfbaca6cfa2101950fb0a68c89771f 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { namespace testing { @@ -120,8 +120,7 @@ class HloShapeAndLayoutMatcher class HloShardingMatcher : public ::testing::MatcherInterface { public: - explicit HloShardingMatcher( - const tensorflow::gtl::optional& sharding) + explicit HloShardingMatcher(const absl::optional& sharding) : sharding_(sharding) {} bool MatchAndExplain(const HloInstruction* instruction, @@ -129,7 +128,7 @@ class HloShardingMatcher void DescribeTo(std::ostream* os) const override; private: - tensorflow::gtl::optional sharding_; + absl::optional sharding_; }; // Matches a Dot HLO instruction with specific LHS and RHS contracting @@ -189,6 +188,7 @@ HLO_MATCHER(Fusion); HLO_MATCHER(Ge); HLO_MATCHER(AfterAll); HLO_MATCHER(Gt); +HLO_MATCHER(Iota); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); HLO_MATCHER(Le); @@ -231,6 +231,7 @@ HLO_MATCHER(Tanh); HLO_MATCHER(Trace); HLO_MATCHER(Transpose); HLO_MATCHER(Tuple); +HLO_MATCHER(TupleSelect); HLO_MATCHER(While); // The special cases below let you check additional information about the @@ -306,7 +307,7 @@ inline ::testing::Matcher Shape( return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape)); } inline ::testing::Matcher Shape( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -316,7 +317,7 @@ inline ::testing::Matcher ShapeWithLayout( new ::xla::testing::HloShapeAndLayoutMatcher(shape)); } inline ::testing::Matcher ShapeWithLayout( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -329,14 +330,14 @@ inline ::testing::Matcher Sharding( } // Matcher for Sharding from sharding string inline ::testing::Matcher Sharding( - tensorflow::StringPiece sharding) { + absl::string_view sharding) { return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( ParseSharding(sharding).ValueOrDie())); } // Verifies that no HloSharding is set for an HLO instruction. inline ::testing::Matcher NoSharding() { return ::testing::MakeMatcher( - new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); + new ::xla::testing::HloShardingMatcher(absl::nullopt)); } inline ::testing::Matcher Dot( diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 7de59acc1efbc0150b95ebdd85a13ede48eec2f9..7961aece541faeb66875885b380158756c503250 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -157,9 +157,8 @@ TEST(HloMatchersTest, ShardingMatcher) { Array assignment({2}); assignment.SetValues({0, 1}); auto sharding = HloSharding::Tuple( - tuple_shape, - {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), - HloSharding::AssignDevice(1), HloSharding::Replicate()}); + tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1), + HloSharding::Replicate()}); p2->set_sharding(sharding); EXPECT_THAT(p0.get(), op::NoSharding()); @@ -172,8 +171,7 @@ TEST(HloMatchersTest, ShardingMatcher) { EXPECT_THAT( p2.get(), - op::Sharding( - "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}")); EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc similarity index 92% rename from tensorflow/compiler/xla/service/hlo_scheduling.cc rename to tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 27cc5361cde2fa021b9489f98217ae5648afc2ad..6a4e766788f47cad9e168fcccd3a3de9097cacdc 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include +#include #include #include @@ -28,16 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Class implementing a list scheduler of HLO instructions which produces a // sequence which minimizes memory usage by preferring to schedule the node that // frees bigger buffer and defines smaller outputs. @@ -71,7 +70,7 @@ class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. - static StatusOr> Run( + static StatusOr Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -230,8 +229,8 @@ class ListScheduler { return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; } - std::vector CreateSchedule() { - std::vector schedule; + HloInstructionSequence CreateSchedule() { + HloInstructionSequence schedule; // Populate the ready list with instructions which have no operands or // control predecessors. @@ -375,7 +374,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> ScheduleComputationHelper( +StatusOr ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -393,7 +392,7 @@ StatusOr> ScheduleComputationHelper( } // namespace -StatusOr> DFSMemoryScheduler( +StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -401,7 +400,7 @@ StatusOr> DFSMemoryScheduler( memory_by_computation) { // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); + int64 total_hlos = computation.parent()->instruction_count(); tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -444,7 +443,7 @@ StatusOr> DFSMemoryScheduler( // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a // tiebreaker by name for determinism. - std::vector sequence; + HloInstructionSequence sequence; FunctionVisitor visitor([&sequence](HloInstruction* hlo) { sequence.push_back(hlo); return Status::OK(); @@ -464,7 +463,7 @@ StatusOr> DFSMemoryScheduler( return sequence; } // namespace xla -StatusOr> ListMemoryScheduler( +StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -474,18 +473,16 @@ StatusOr> ListMemoryScheduler( memory_by_computation); } -StatusOr> PostOrderMemoryScheduler( +StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation) { - const auto& post_order = computation.MakeInstructionPostOrder(); - return std::vector{post_order.begin(), - post_order.end()}; + return HloInstructionSequence(computation.MakeInstructionPostOrder()); } -StatusOr> DefaultMemoryScheduler( +StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -500,7 +497,7 @@ StatusOr> DefaultMemoryScheduler( // List wins for most of our benchmarks; postorder-based schedulers win for // some RNNs. TF_ASSIGN_OR_RETURN( - std::vector list_sequence, + HloInstructionSequence list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, @@ -509,7 +506,7 @@ StatusOr> DefaultMemoryScheduler( size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, @@ -519,7 +516,7 @@ StatusOr> DefaultMemoryScheduler( VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( - std::vector post_order_sequence, + HloInstructionSequence post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, @@ -546,32 +543,35 @@ StatusOr> DefaultMemoryScheduler( } } -StatusOr ScheduleComputationsInModule( +StatusOr ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm) { - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(&module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); tensorflow::gtl::FlatMap memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( - *computation, one_computation_sequence, *points_to_analysis, + *computation, computation_sequence, *points_to_analysis, size_function, &memory_by_computation) .ValueOrDie(); - sequence[computation] = std::move(one_computation_sequence); + schedule.set_sequence(computation, std::move(computation_sequence)); } } - VLOG(1) << "Module schedule:\n" << sequence; - return sequence; + VLOG(1) << "Module schedule:\n" << schedule; + + TF_RETURN_IF_ERROR(schedule.Verify()); + + return std::move(schedule); } -StatusOr> ScheduleOneComputation( +StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); @@ -582,4 +582,22 @@ StatusOr> ScheduleOneComputation( size_function, nullptr, empty_map); } +HloMemoryScheduler::HloMemoryScheduler( + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) + : size_function_(size_function), algorithm_(algorithm) {} + +StatusOr HloMemoryScheduler::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, size_function_, algorithm_)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + return true; +} + +StatusOr HloDescheduler::Run(HloModule* module) { + bool changed = module->has_schedule(); + module->clear_schedule(); + return changed; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h similarity index 62% rename from tensorflow/compiler/xla/service/hlo_scheduling.h rename to tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 2b33ccc8bfb895286bb3747aab0a16cf25e2cfae..9964c6fdd7c60a807896ea7aaaa9d55767f20f51 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -13,14 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ #include #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -32,14 +34,14 @@ namespace xla { // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. -typedef std::function>( +typedef std::function( const HloComputation&, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const tensorflow::gtl::FlatMap&)> MemorySchedulerAlgorithm; // List scheduler -StatusOr> ListMemoryScheduler( +StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -47,7 +49,7 @@ StatusOr> ListMemoryScheduler( memory_by_computation); // DFS-order scheduler -StatusOr> DFSMemoryScheduler( +StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -55,7 +57,7 @@ StatusOr> DFSMemoryScheduler( memory_by_computation); // Naive Post Order scheduler -StatusOr> PostOrderMemoryScheduler( +StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -65,26 +67,57 @@ StatusOr> PostOrderMemoryScheduler( // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. -StatusOr> DefaultMemoryScheduler( +StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation); -// Returns an HloModuleSequence which seeks to minimize the memory required for +// Returns an HloSchedule which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr ScheduleComputationsInModule( +StatusOr ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr> ScheduleOneComputation( +StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); +// A pass which schedules the HLO instructions in a module. The HloModule's +// schedule field is set to the resulting HloSchedule using +// HloModule::set_schedule. +class HloMemoryScheduler : public HloModulePass { + public: + // size_function is the function returning the number of bytes required for a + // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not + // specified, then DefaultMemoryScheduler is used. + HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); + ~HloMemoryScheduler() override = default; + absl::string_view name() const override { return "hlo-memory-scheduler"; } + + StatusOr Run(HloModule* module) override; + + private: + LogicalBuffer::SizeFunction size_function_; + MemorySchedulerAlgorithm algorithm_; +}; + +// A trivial pass which clears the schedule currently set on the +// HloModule. After this pass runs HloModudle::has_schedule will return false. +class HloDescheduler : public HloModulePass { + public: + HloDescheduler() = default; + ~HloDescheduler() override = default; + absl::string_view name() const override { return "hlo-descheduler"; } + + StatusOr Run(HloModule* module) override; +}; + } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc similarity index 79% rename from tensorflow/compiler/xla/service/hlo_scheduling_test.cc rename to tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 9ec983c2bc353955cb23d441d200ac8aa36951b1..1b9e9bfc77c3ba91e5b878f4aa42d26d8267a49a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -28,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -64,21 +67,34 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); + HloMemoryScheduler scheduler([](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector& sequence = + module->schedule().sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); + EXPECT_EQ(param, sequence.front()); + EXPECT_EQ(sub, sequence.back()); - SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(module->schedule()); EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); + + // Clear the schedule using the descheduling pass. + HloDescheduler descheduler; + EXPECT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed, + descheduler.Run(module.get())); + EXPECT_TRUE(descheduler_changed); + EXPECT_FALSE(module->has_schedule()); } TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { @@ -106,28 +122,26 @@ ENTRY root { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); std::unordered_map instructions_by_name; - for (const HloInstruction* instruction : - sequence.at(module->entry_computation())) { + for (const HloInstruction* instruction : sequence) { instructions_by_name[instruction->name()] = instruction; } // The first instruction should be the parameter and the last the root. - EXPECT_EQ(instructions_by_name.at("param"), - sequence.at(module->entry_computation()).front()); - EXPECT_EQ(instructions_by_name.at("result"), - sequence.at(module->entry_computation()).back()); + EXPECT_EQ(instructions_by_name.at("param"), sequence.front()); + EXPECT_EQ(instructions_by_name.at("result"), sequence.back()); // Instructions "d" and "e" will both be schedulable at the same time, but // instruction "d" allows us to free the buffer of "p1", so the list scheduler // should prefer it. - SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), instructions_by_name.at("e"))); } @@ -218,13 +232,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(entry_computation).size()); + SequentialHloOrdering ordering(schedule); // This schedule is an example of List's greedy heuristics being suboptimal. // The while_loop is more expensive than transpose, so it would have been // better to schedule it first, instead of during the busy time. @@ -241,13 +255,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); - // HeapSimulator accounts for subcomputations. The max mem doesn't change - // because the while body isn't live during the peak. - EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + // HeapSimulator accounts for subcomputations. The output buffer is aliased, + // so we don't double count. + EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } @@ -267,7 +281,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto abs_abs1 = builder.AddInstruction( HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( - tensorflow::gtl::ArraySlice({abs_abs1}))); + absl::Span({abs_abs1}))); auto tuple_elm = builder.AddInstruction( HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); @@ -279,19 +293,18 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // tuple allocates the tuple buffer and doesn't free anything. // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. // abs_abs2 should be scheduled before tuple by List. @@ -330,18 +343,18 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), 2); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // fusion allocates memory for the tuple elements and doesn't free anything, // so it's more expensive than exp. EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); @@ -350,7 +363,6 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { auto module = CreateNewModule(); const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); // param != 0 // Needs 17 bytes @@ -390,12 +402,12 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); tensorflow::gtl::FlatMap memory_by_computation; memory_by_computation[cond_computation] = 17; @@ -405,12 +417,13 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); - // HeapSimulator accounts for subcomputations - EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + // HeapSimulator accounts for subcomputations. Cond is the largest one. + // The output buffer of the while is aliased. + EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 55ff073d3faf34aa0f1b8f0886946837e7a49bcc..b3949f3a6d7176950c61cafb0830d1175f17758d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -22,12 +22,14 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -49,9 +51,16 @@ StatusOr HloModule::LaunderConstInstructionFromModule( return const_cast(hlo); } +Status HloModule::set_schedule(HloSchedule schedule) { + TF_RET_CHECK(schedule.module() == this); + TF_RETURN_IF_ERROR(schedule.Verify()); + schedule_ = std::move(schedule); + return Status::OK(); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, - bool uniquify_names) { + bool uniquify_identifiers) { if (is_entry) { CHECK_EQ(nullptr, entry_computation_); entry_computation_ = computation.get(); @@ -64,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal( } } - if (uniquify_names) { + if (uniquify_identifiers) { computation->UniquifyName(&computation_name_uniquer_); for (auto* instruction : computation->instructions()) { instruction->UniquifyName(&instruction_name_uniquer_); } + + // Pick unique IDs for each instruction. + for (auto* instruction : computation->instructions()) { + instruction->SetUniqueId(NewUniqueInstructionId()); + } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); } else { // Don't uniquify the names of the computation or instruction, but we must // run the names through the uniquifiers to prevent future name collisions - // for computations and instructions created later. + // for computations and instructions created later. Also, set the + // next_unique_id_ to the one greater than the max unique id of any + // instruction (or the computation) to avoid ID collisions. computation_name_uniquer_.GetUniqueName(computation->name()); for (auto* instruction : computation->instructions()) { instruction_name_uniquer_.GetUniqueName(instruction->name()); + next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1); + } + if (next_unique_id_ < computation->unique_id() + 1) { + next_unique_id_ = computation->unique_id() + 1; } } - // Pick unique IDs for each instruction. - for (auto* instruction : computation->instructions()) { - instruction->SetUniqueId(NewUniqueInstructionId()); - } - // Set unique id to this computation. - CHECK_NE(computation->root_instruction()->unique_id(), -1) - << "Root has no valid id: " << computation->ToString(); - computation->SetUniqueId(computation->root_instruction()->unique_id()); - computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -96,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal( HloComputation* HloModule::AddEntryComputation( std::unique_ptr computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/true, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { @@ -113,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/false, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } void HloModule::ReplaceComputations( @@ -197,12 +212,23 @@ void HloModule::ReplaceComputations( string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << "\n\n"; + s << "HloModule " << name(); + if (has_schedule()) { + TF_CHECK_OK(schedule().Verify()); + s << ", is_scheduled=true"; + } + s << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString(options) << "\n\n"; + if (has_schedule() && schedule().is_computation_scheduled(computation)) { + s << computation->ToString( + options, schedule().sequence(computation).instructions()) + << "\n\n"; + } else { + s << computation->ToString(options) << "\n\n"; + } } return s.str(); } @@ -220,12 +246,18 @@ HloModuleProto HloModule::ToProto() const { } proto.add_computations()->Swap(&computation_proto); } + if (has_schedule()) { + *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); + } return proto; } /* static */ StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(2) << "CreateFromProto()"; + XLA_VLOG_LINES(2, proto.DebugString()); + // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. TF_RET_CHECK(proto.has_program_shape()) @@ -274,7 +306,7 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(entry != nullptr); - auto module = MakeUnique(proto.name(), module_config); + auto module = absl::make_unique(proto.name(), module_config); // Sort the computations in the proto id's order. std::sort(computations.begin(), computations.end(), @@ -289,25 +321,42 @@ StatusOr> HloModule::CreateFromProto( // Don't uniquify names because we want names to be stable across // serialization and deserialization. module->AddComputationInternal(std::move(computation), is_entry, - /*uniquify_names=*/false); + /*uniquify_identifiers=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); - // Because we didn't uniquify the names, double-check that the instruction and - // computation names are unique from the proto. + // Because we didn't uniquify the names or the ids, double-check that the + // instruction and computation names and ids are unique from the proto. tensorflow::gtl::FlatSet computation_names; tensorflow::gtl::FlatSet instruction_names; + tensorflow::gtl::FlatSet computation_ids; + tensorflow::gtl::FlatSet instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); for (HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) << "Instruction name is not unique: " << instruction->name(); instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); } } + if (proto.has_schedule()) { + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + HloSchedule::CreateFromProto(module.get(), proto.schedule())); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + } + return std::move(module); } @@ -352,7 +401,7 @@ bool IsUsedOutsideSubcomputation( } // anonymous namespace HloInstruction* HloModule::OutlineExpressionFromComputation( - tensorflow::gtl::ArraySlice instructions_to_outline, + absl::Span instructions_to_outline, const string& outlined_computation_name, HloComputation* computation) { auto builder = HloComputation::Builder(outlined_computation_name); @@ -409,7 +458,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( string error_message = "The subcomputation to outline has multiple outputs:\n"; for (HloInstruction* output : outputs) { - tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n"); + absl::StrAppend(&error_message, output->ToString(), "\n"); } LOG(FATAL) << error_message; } @@ -507,7 +556,7 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = MakeUnique(name_ + "-" + suffix, config_); + auto module = absl::make_unique(name_ + "-" + suffix, config_); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); @@ -535,12 +584,11 @@ uint64 HloModule::RandomNew64() const { return rng_(); } -HloComputation* HloModule::GetComputationWithName( - tensorflow::StringPiece name) { +HloComputation* HloModule::GetComputationWithName(absl::string_view name) { auto computations_in_module = computations(); - auto it = c_find_if(computations_in_module, [&](HloComputation* computation) { - return computation->name() == name; - }); + auto it = absl::c_find_if( + computations_in_module, + [&](HloComputation* computation) { return computation->name() == name; }); return it == computations_in_module.end() ? nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index d2e726a0db63f622cd5092d56b4f746232d04aad..735804e827afd77e2b7f2a4a7d490ee6f5ee7b4f 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -24,16 +24,18 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" @@ -61,6 +63,7 @@ class HloModule { // tests). The versioned handle is used by the service in the compilation // cache. A default configuration is created for this module. explicit HloModule(const string& name, const HloModuleConfig& config); + virtual ~HloModule() {} // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. @@ -85,6 +88,7 @@ class HloModule { const std::unordered_map& replacements); const string& name() const { return name_; } + void set_name(string name) { name_ = std::move(name); } // Returns a deep copy of this module including all computations. std::unique_ptr Clone(const string& suffix = "clone") const; @@ -142,7 +146,7 @@ class HloModule { // Returns the computation in this module that has the name `name`. Returns // null if there is no such computation. - HloComputation* GetComputationWithName(tensorflow::StringPiece name); + HloComputation* GetComputationWithName(absl::string_view name); // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } @@ -192,7 +196,7 @@ class HloModule { // order (root of outlined instructions last). TODO(jingyue): takes a set of // instructions and topologically sorts them. HloInstruction* OutlineExpressionFromComputation( - tensorflow::gtl::ArraySlice instructions_to_outline, + absl::Span instructions_to_outline, const string& outlined_computation_name, HloComputation* computation); // Returns a randomly generated uint64. @@ -235,12 +239,25 @@ class HloModule { StatusOr LaunderConstInstructionFromModule( const HloInstruction* hlo); + // Sets the schedule of the module to the given schedule. + Status set_schedule(HloSchedule schedule); + + // Clears the schedule of the module. + void clear_schedule() { schedule_.reset(); } + + // Returns true if the module has a schedule set. + bool has_schedule() const { return schedule_.has_value(); } + + // Returns the schedue of the module. CHECK fails if no schedule is set. + const HloSchedule& schedule() const { return *schedule_; } + HloSchedule& schedule() { return *schedule_; } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, - bool uniquify_names); + bool uniquify_identifiers); - const string name_; + string name_; HloModuleConfig config_; HloComputation* entry_computation_ = nullptr; std::vector> computations_; @@ -262,6 +279,11 @@ class HloModule { static std::atomic next_unique_module_id_; // A unique id to label modules with. int unique_id_; + + // The HloSchedule of the module. The schedule if it exists contains a + // sequential order of instructions for each non-fusion computation in the + // module. + absl::optional schedule_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 07a8c798dbee072db3b75d5e99ca0dcabb5fdf6b..9bfa3a5f45c8e810f9ea7d6bdcd72b90254d15b9 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrAppend; +using absl::StrAppend; HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape, bool ignore_layouts) @@ -39,15 +39,14 @@ void HloModuleConfig::SetDefaultComputationLayout( } string HloModuleConfig::compilation_cache_key() const { - string key = - tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled()); + string key = absl::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } - StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", + StrAppend(&key, absl::StrJoin(params, ", "), ") => ", entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 074e9c90705d432b8344aebaf3c15aeb41a59fa3..68c18836eb01484b819e7b7bd26f099dcf56e7ba 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -18,11 +18,11 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -72,15 +72,6 @@ class HloModuleConfig { return debug_options_.xla_hlo_profile(); } - // Sets/returns whether this is a "host module". Host modules are used to - // record the data- and control-flow dependencies of host side computation - // that communicates with compiled code. They are used for analysis and - // scheduling purposes, but no code is generated. - bool is_host_module() const { return is_host_module_; } - void set_is_host_module(bool is_host_module) { - is_host_module_ = is_host_module; - } - // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } uint64 seed() const { return seed_; } @@ -113,10 +104,7 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional entry_computation_layout_; - - // Whether this is a 'host module'. - bool is_host_module_ = false; + absl::optional entry_computation_layout_; // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index 98d20315e399c6b1a3979b5d11a89ef93869f4d9..31d26cc51e8217234526bbfeb83510aadf2c27b5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -36,23 +36,6 @@ namespace xla { namespace { -bool HasSendRecv(HloComputation* computation) { - for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kSendDone || - instruction->opcode() == HloOpcode::kRecv || - instruction->opcode() == HloOpcode::kRecvDone) { - return true; - } - for (auto* sub_computation : instruction->called_computations()) { - if (HasSendRecv(sub_computation)) { - return true; - } - } - } - return false; -} - StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { bool changed = false; for (auto* computation : module->computations()) { @@ -67,10 +50,9 @@ StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { auto* while_body_root = while_body_comp->root_instruction(); if (!ShapeUtil::IsTuple(xla_while->shape()) || - while_body_root->opcode() != HloOpcode::kTuple || - HasSendRecv(while_body_comp)) { + while_body_root->opcode() != HloOpcode::kTuple) { // Only run DCE on tuple-shaped while loops where body root is Tuple, - // with no send/recv instructions. + // with no I/O instructions. VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); continue; } diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h index 29024085c1038961ef2b3721de1ce0e8a55ccf45..d472211d2af6e4b583d3815146ba8cee5c8e7495 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.h +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -28,10 +28,10 @@ namespace xla { // Sweeps through live instructions which cross computation boundaries (kWhile), // and removes code at dead shape indices. // -class HloModuleDCE : public HloPassInterface { +class HloModuleDCE : public HloModulePass { public: ~HloModuleDCE() override {} - tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + absl::string_view name() const override { return "hlo-module-dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index 363862e4905fc13a4ef07aeaac255259fc6b86ba..bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -367,5 +367,77 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { "while.2", 1)); } +// Tests that a while whose body has outfeed operations is not DCE-ed. +TEST_F(HloModuleDceTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + +// Tests that if a loop variable is not referenced outside of a kWhile, the loop +// variable changes are not elided within the loop body, if the condition +// computation uses them. +TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { + auto module = ParseHloString(R"( + HloModule InfiniteLoop + WhileBody { + body_param = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2) + } + WhileCondition { + cond_param = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + p0 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(p0), index=0 + constant.3 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5) + while = (s32[], s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc new file mode 100644 index 0000000000000000000000000000000000000000..f9b56ef4643f2ca88e56456ae6c990161adb5085 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.cc @@ -0,0 +1,91 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +namespace xla { + +HloModuleGroup::HloModuleGroup(absl::string_view name, + std::unique_ptr module) + : name_(name) { + push_back(std::move(module)); +} + +HloModuleGroup::HloModuleGroup(absl::string_view name, + absl::Span> modules) + : name_(name) { + for (auto& module : modules) { + push_back(std::move(module)); + } +} + +std::vector> HloModuleGroup::ConsumeModules() { + std::vector> ret_modules = std::move(modules_); + + // Clear everything so the object state is in a known (empty) state. + modules_.clear(); + module_ptrs_.clear(); + return ret_modules; +} + +string HloModuleGroup::ToString() const { + std::ostringstream s; + s << "HloModuleGroup " << name() << "\n\n"; + for (const HloModule* module : modules()) { + s << module->ToString() << "\n"; + } + return s.str(); +} + +HloModuleGroupProto HloModuleGroup::ToProto() const { + HloModuleGroupProto proto; + proto.set_name(name()); + for (const HloModule* module : modules()) { + *proto.add_hlo_modules() = module->ToProto(); + } + return proto; +} + +/* static */ StatusOr HloModuleGroup::CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs) { + TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty"; + TF_RET_CHECK(proto.hlo_modules_size() > 0) + << "Module group must have at least one HLO module"; + TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size()); + + std::vector> modules; + for (int i = 0; i < proto.hlo_modules_size(); ++i) { + const HloModuleProto& module_proto = proto.hlo_modules(i); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(module_proto, module_configs[i])); + modules.push_back(std::move(module)); + } + + return HloModuleGroup(proto.name(), absl::MakeSpan(modules)); +} + +void HloModuleGroup::push_back(std::unique_ptr module) { + modules_.push_back(std::move(module)); + module_ptrs_.push_back(modules_.back().get()); +} + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) { + out << group.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h new file mode 100644 index 0000000000000000000000000000000000000000..7338be8b9c5ed47f0ba5829cc1d603b21f00b6e0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +// An abstraction representing a ordered set of HLO module built to run +// concurrently across different devices. +class HloModuleGroup { + public: + // Construct an empty module group. + explicit HloModuleGroup(absl::string_view name) : name_(name) {} + + // Construct a module group containing a single module. + HloModuleGroup(absl::string_view name, std::unique_ptr module); + + // Construct a module group containing any number of modules. + HloModuleGroup(absl::string_view name, + absl::Span> modules); + + // Returns the modules contained in the group. + const std::vector& modules() const { return module_ptrs_; } + + // Returns a module at a particular index. + HloModule& module(int index) const { return *module_ptrs_.at(index); } + + // Add a module to the back of vector of modules in the group. + void push_back(std::unique_ptr module); + + // Moves all modules from the group into the returned vector. After this + // method runs, the module group will be empty. + std::vector> ConsumeModules(); + + string name() const { return name_; } + string ToString() const; + + // Serialize the module group to/from a proto. + HloModuleGroupProto ToProto() const; + static StatusOr CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs); + + private: + string name_; + + // Vector of modules as std::unique_ptrs. + std::vector> modules_; + + // Vector of modules as normal pointers. This vector is kept in sync with + // modules_ as modules are added to the group with push_back. + std::vector module_ptrs_; +}; + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 10bf9ffd6c1960df5ca2a3555d120b0874407f15..83352ef91b35b61ee2560b1488ee2ecdff6bea0a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -19,9 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -59,7 +60,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { /* static */ StatusOr> HloModuleGroupMetadata::Build(const std::vector& modules) { - auto metadata = MakeUnique(modules); + auto metadata = absl::make_unique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); } @@ -131,6 +132,14 @@ Status HloModuleGroupMetadata::Build() { if (VLOG_IS_ON(4)) { DumpCollectedStats(); } + + for (HloModule* module : modules_) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + points_to_analyses_[module] = std::move(points_to_analysis); + } + return Status::OK(); } @@ -163,7 +172,7 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const { ss << " " << hlo->name() << std::endl; } ss << "has multiple instructions on the same device"; - return FailedPrecondition("%s", ss.str().c_str()); + return FailedPrecondition("%s", ss.str()); } } } @@ -204,6 +213,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( return channels_[channel_id_map_.at(channel_id)]; } +bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const { + return channel_id_map_.find(channel_id) != channel_id_map_.end(); +} + HloComputation* HloModuleGroupMetadata::PeerComputation( const HloInstruction* instruction) const { CHECK(IsChannelInstruction(instruction)); @@ -267,15 +280,14 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { LOG(FATAL) << "unknown module"; } -tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( +absl::optional HloModuleGroupMetadata::GetInstructionDevice( const HloInstruction& instruction) const { // The module group metadata can be created in both "single module, multiple // devices" and "multiple modules, no explicit devices" fashions. // The API returns an optional even though the current implementation always // returns a device, to account for cases where we cannot guess a device. // In such cases the VerifyChannelInstructions() will return proper errors. - tensorflow::gtl::optional device = - instruction.sharding_unique_device(); + absl::optional device = instruction.sharding_unique_device(); if (!device) { device = GetModuleId(instruction.parent()->parent()); } @@ -283,10 +295,7 @@ tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( } int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { - return std::count_if(modules_.begin(), modules_.end(), - [](const HloModule* module) { - return !module->config().is_host_module(); - }); + return modules_.size(); } Status HloModuleGroupMetadata::RecordInstructions() { @@ -383,22 +392,28 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - tensorflow::MakeUnique>()); + absl::make_unique>()); auto companion_set = companion_sets_.back().get(); - companion_set->insert(instruction1); - companion_set->insert(instruction2); + companion_set->push_back(instruction1); + companion_set->push_back(instruction2); companion_set_index_[instruction1] = companion_sets_.size() - 1; companion_set_index_[instruction2] = companion_sets_.size() - 1; } else if (!ContainsKey(companion_set_index_, instruction1)) { - companion_sets_[companion_set_index_[instruction2]]->insert(instruction1); + companion_sets_[companion_set_index_[instruction2]]->push_back( + instruction1); companion_set_index_[instruction1] = companion_set_index_[instruction2]; } else if (!ContainsKey(companion_set_index_, instruction2)) { - companion_sets_[companion_set_index_[instruction1]]->insert(instruction2); + companion_sets_[companion_set_index_[instruction1]]->push_back( + instruction2); companion_set_index_[instruction2] = companion_set_index_[instruction1]; } else if (companion_set_index_[instruction1] != companion_set_index_[instruction2]) { - companion_sets_[companion_set_index_[instruction1]]->insert( - Companions(instruction2).begin(), Companions(instruction2).end()); + // At any point while building the companion sets, each instruction belongs + // to at most 1 companion set, so the union of two companion sets is + // concatenating two disjoint sets. + absl::c_copy(Companions(instruction2), + std::back_inserter( + *companion_sets_[companion_set_index_[instruction1]])); int64 index_to_remove = companion_set_index_[instruction2]; for (HloInstruction* hlo : Companions(instruction2)) { companion_set_index_[hlo] = companion_set_index_[instruction1]; @@ -411,16 +426,16 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, Status HloModuleGroupMetadata::VerifyChannelInstructions() { for (const Channel& channel : channels_) { if (channel.send == nullptr) { - return FailedPrecondition("missing send for id : %lld", channel.id); + return FailedPrecondition("missing send for id : %d", channel.id); } if (channel.recv == nullptr) { - return FailedPrecondition("missing recv for id : %lld", channel.id); + return FailedPrecondition("missing recv for id : %d", channel.id); } if (channel.send_done == nullptr) { - return FailedPrecondition("missing send-done for id : %lld", channel.id); + return FailedPrecondition("missing send-done for id : %d", channel.id); } if (channel.recv_done == nullptr) { - return FailedPrecondition("missing recv-done for id : %lld", channel.id); + return FailedPrecondition("missing recv-done for id : %d", channel.id); } } @@ -436,33 +451,33 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { auto send_done_device = GetInstructionDevice(*channel.send_done); if (!send_device) { return FailedPrecondition("send instruction must have a device: %s", - channel.send->ToString().c_str()); + channel.send->ToString()); } if (!send_done_device) { return FailedPrecondition("send_done instruction must have a device: %s", - channel.send_done->ToString().c_str()); + channel.send_done->ToString()); } if (*send_device != *send_done_device) { return FailedPrecondition( - "send and send-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "send and send-done (channel=%d) must be on the same device: %d " + "vs. %d", channel.id, *send_device, *send_done_device); } auto recv_device = GetInstructionDevice(*channel.recv); auto recv_done_device = GetInstructionDevice(*channel.recv_done); if (!recv_done_device) { return FailedPrecondition("recv_done instruction must have a device: %s", - channel.recv_done->ToString().c_str()); + channel.recv_done->ToString()); } if (*recv_device != *recv_done_device) { return FailedPrecondition( - "recv and recv-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "recv and recv-done (channel=%d) must be on the same device: %d " + "vs. %d", channel.id, *recv_device, *recv_done_device); } if (*send_device == *recv_device) { return FailedPrecondition( - "send and recv (channel=%lld) must be on different devices: %lld", + "send and recv (channel=%d) must be on different devices: %d", channel.id, *send_device); } } @@ -483,7 +498,7 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { !CheckCompanionPathsCompatibility( path, GetCompanionsPath(channel.recv_done))) { return FailedPrecondition( - "Nest companion paths do not match for channel %lld", channel.id); + "Nest companion paths do not match for channel %d", channel.id); } } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 84f2d3f5fbc1a6ff1df8ba3c0babd122e5701148..278d94cdd337c835bc0ff98ea577ef7b8c3ddd03 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -22,14 +22,15 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -125,6 +126,9 @@ class HloModuleGroupMetadata { // Returns the Channel instance for the given channel id. const Channel& GetChannel(int64 channel_id) const; + // Returns if the given channel id exists in metadata. + bool HasChannel(int64 channel_id) const; + // Returns the all-reduce instructions with the same all_reduce_id. const std::vector& GetAllReduceGroup( int64 all_reduce_id) const; @@ -156,7 +160,7 @@ class HloModuleGroupMetadata { // Retrieves the device an instruction is assigned to. Either from the // sharding information, or from the ordinal of the module the instruction // is in. - tensorflow::gtl::optional GetInstructionDevice( + absl::optional GetInstructionDevice( const HloInstruction& instruction) const; // Returns the number of modules for devices (excluding the host module). @@ -165,14 +169,14 @@ class HloModuleGroupMetadata { // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. - const std::unordered_set& Companions( - HloInstruction* instruction) const { + const std::vector& Companions( + const HloInstruction* instruction) const { CHECK_EQ(companion_set_index_.count(instruction), 1); return companion_set(companion_set_index_.at(instruction)); } // Returns the companion set at the given index. - const std::unordered_set& companion_set(int64 index) const { + const std::vector& companion_set(int64 index) const { CHECK_LT(index, companion_sets_.size()); return *companion_sets_[index]; } @@ -183,7 +187,7 @@ class HloModuleGroupMetadata { } // Returns the list of all companion sets in the HLO module group. - const std::vector>>& + const std::vector>>& companion_sets() const { return companion_sets_; } @@ -194,6 +198,10 @@ class HloModuleGroupMetadata { // Returns the maximum channel id or all_reduce_id used in the module group. int64 max_channel_id() const { return max_channel_id_; } + TuplePointsToAnalysis* points_to_analysis(HloModule* module) const { + return points_to_analyses_.at(module).get(); + } + private: Status Build(); @@ -239,11 +247,10 @@ class HloModuleGroupMetadata { void DumpCollectedStats() const; // List of all companion instructions sets in the module. - std::vector>> - companion_sets_; + std::vector>> companion_sets_; // Map from each companion while instruction to the index into companion_set_. - tensorflow::gtl::FlatMap companion_set_index_; + tensorflow::gtl::FlatMap companion_set_index_; // Map from computation to the instruction using it (a kWhile, kConditional). tensorflow::gtl::FlatMap @@ -268,6 +275,9 @@ class HloModuleGroupMetadata { // The modules that this metadata was built from. const std::vector& modules_; + + tensorflow::gtl::FlatMap> + points_to_analyses_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b7b12cb72b8df4610b964fb842da78e160d22d9f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +class HloModuleGroupTest : public HloTestBase { + protected: + HloModuleGroupTest() = default; +}; + +TEST_F(HloModuleGroupTest, SingleModule) { + const string text = R"( +HloModule simple_module + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + HloModuleGroup group(TestName(), std::move(module)); + + EXPECT_EQ(group.modules().size(), 1); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config()})); + EXPECT_EQ(group_copy.modules().size(), 1); + EXPECT_THAT( + group_copy.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + std::vector> modules = group.ConsumeModules(); + EXPECT_EQ(modules.size(), 1); + EXPECT_EQ(group.modules().size(), 0); +} + +TEST_F(HloModuleGroupTest, MultipleModules) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + std::vector> modules; + modules.push_back(std::move(module_0)); + modules.push_back(std::move(module_1)); + HloModuleGroup group(TestName(), absl::MakeSpan(modules)); + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config(), + group.module(1).config()})); + EXPECT_EQ(group_copy.modules().size(), 2); +} + +TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + HloModuleGroup group(TestName()); + group.push_back(std::move(module_0)); + group.push_back(std::move(module_1)); + + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); +} + +// Tests that the order of companion instructions in the companion set doesn't +// change across runs. +TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) { + // A simple while loop template for core i sending to core i+1. + constexpr char text[] = R"( +HloModule module_%d + +while_cond { + ROOT p = pred[] constant(true) +} + +while_body { + param = s32[] parameter(0) + token.s = token[] after-all() + token.r = token[] after-all() + send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d + send-done = token[] send-done(send), channel_id=%d + recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d + ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d +} + +ENTRY entry { + while_init = s32[] constant(1) + ROOT while = s32[] while(while_init), condition=while_cond, body=while_body +} +)"; + + // Try creating the module and the metadata kTrialCount times and check the + // companion instructions remain in the same order. + const int64 kTrialCount = 5; + const int64 kDeviceCount = 10; + std::vector companion_order; + + for (int64 t = 0; t < kTrialCount; ++t) { + HloModuleGroup group(TestName()); + for (int64 i = 0; i < kDeviceCount; ++i) { + const int64 send_channel = i; + const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(absl::StrFormat(text, i, send_channel, send_channel, + recv_channel, recv_channel))); + group.push_back(std::move(module)); + } + ASSERT_EQ(group.modules().size(), kDeviceCount); + + TF_ASSERT_OK_AND_ASSIGN(auto metadata, + HloModuleGroupMetadata::Build(group.modules())); + ASSERT_EQ(metadata->companion_sets().size(), 1); + + std::vector module_ids; + for (HloInstruction* companion : *metadata->companion_sets()[0]) { + module_ids.push_back(metadata->GetModuleId(companion->GetModule())); + } + + if (t == 0) { + companion_order = module_ids; + } else { + EXPECT_TRUE(absl::c_equal(companion_order, module_ids)); + } + } +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 9fd0ade153109c6c809c37aa08257f83a82c44d5..d83ee714905252e36f38438e81002a4d6ba7dafa 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,14 +22,17 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -37,24 +40,38 @@ namespace xla { std::vector HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { - std::vector predecessors; - - // Adds to the unique predecessors list and also add companion instructions - // if the given predecessor has those. + std::vector + predecessors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet unique; + + // Adds to the unique predecessors list; if the predecessors is a companion + // instruction, also add companion instructions; if the predecessors is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. auto add_unique_predecessor = [&](HloInstruction* predecessor) { - if (std::find(predecessors.begin(), predecessors.end(), predecessor) != - predecessors.end()) { + if (unique.find(predecessor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(predecessor)) { - predecessors.push_back(predecessor); + if (metadata_.IsCompanionInstruction(predecessor)) { + for (HloInstruction* instr : metadata_.Companions(predecessor)) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(predecessor)) { - predecessors.push_back(companion); + if (predecessor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } + return; } + unique.insert(predecessor); + predecessors.push_back(predecessor); }; - // If the given instruction is a companion instruction, we need to find the // predecessors of all of its companion instructions. If the instruction is an // all-reduce, we need to find the predecessors of all the peer all-reduce @@ -79,12 +96,14 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( add_unique_predecessor(control_predecessor); } } - if (instruction->opcode() == HloOpcode::kRecvDone) { + if (instruction->opcode() == HloOpcode::kRecvDone && + !DynCast(instruction)->is_host_transfer()) { // Send is a remote predecessor of RecvDone. HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; add_unique_predecessor(send); } - if (instruction->opcode() == HloOpcode::kSend) { + if (instruction->opcode() == HloOpcode::kSend && + !DynCast(instruction)->is_host_transfer()) { // Recv is a remote predecessor of Send. HloInstruction* recv_done = metadata_.GetChannel(instruction->channel_id()).recv_done; @@ -98,22 +117,37 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( std::vector HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { - std::vector successors; - - // Adds to the unique successors list and also add companion instructions - // if the given successor has those. + std::vector + successors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet unique; + + // Adds to the unique successors list; if the successor is a companion + // instruction, also add companion instructions; if the successor is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. auto add_unique_successor = [&](HloInstruction* successor) { - if (std::find(successors.begin(), successors.end(), successor) != - successors.end()) { + if (unique.find(successor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(successor)) { - successors.push_back(successor); + if (metadata_.IsCompanionInstruction(successor)) { + for (HloInstruction* instr : metadata_.Companions(successor)) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(successor)) { - successors.push_back(companion); + if (successor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*successor->all_reduce_id())) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } + return; } + unique.insert(successor); + successors.push_back(successor); }; // If the given instruction is a companion instruction, we need to find the @@ -140,14 +174,16 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( add_unique_successor(control_successor); } } - if (instruction->opcode() == HloOpcode::kRecv) { + if (instruction->opcode() == HloOpcode::kRecv && + !DynCast(instruction)->is_host_transfer()) { // Send is a remote successor of Recv. const HloInstruction* recv_done = instruction->users().front(); CHECK(recv_done->opcode() == HloOpcode::kRecvDone); HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; add_unique_successor(send); } - if (instruction->opcode() == HloOpcode::kSend) { + if (instruction->opcode() == HloOpcode::kSend && + !DynCast(instruction)->is_host_transfer()) { // RecvDone is a remote successor of Send. HloInstruction* recv_done = metadata_.GetChannel(instruction->channel_id()).recv_done; @@ -157,7 +193,7 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( } std::vector HloModuleGroupUtil::RootInstructions( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { std::vector roots; for (HloComputation* computation : computations) { for (HloInstruction* instruction : computation->instructions()) { @@ -234,8 +270,8 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( string cyclic_instructions; for (const auto& state : *visit_state) { if (state.second == VisitState::kVisiting) { - tensorflow::strings::StrAppend(&cyclic_instructions, - state.first->ToString(), "\n"); + absl::StrAppend(&cyclic_instructions, state.first->ToString(), + "\n"); } } // TODO(b/64305524): Improve the error message to print out the @@ -246,7 +282,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( "following nodes. Note that the order of the nodes is arbitrary " "and that the list may include nodes that are not part of the " "cycle.\n%s", - predecessor->ToString().c_str(), cyclic_instructions.c_str()); + predecessor->ToString(), cyclic_instructions); } stack.push(predecessor); } @@ -257,7 +293,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( } Status HloModuleGroupUtil::VerifyComputations( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { auto visit_function = [&](HloInstruction* instruction, const std::vector& instruction_group) { @@ -288,7 +324,7 @@ Status HloModuleGroupUtil::VerifyComputations( StatusOr> HloModuleGroupUtil::ComputeReachability( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { std::vector post_order; auto visit_function = [&](HloInstruction* instruction, @@ -302,7 +338,7 @@ HloModuleGroupUtil::ComputeReachability( TF_RETURN_IF_ERROR( VisitTopologicalOrder(&visit_states, visit_function, root)); } - auto reachability = MakeUnique(post_order); + auto reachability = absl::make_unique(post_order); for (HloInstruction* hlo : post_order) { reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index c25ca1aff50b288f3ac3885cbed53e7ba9768430..309c23045d1e0dd91e2f245d00c51d9bf9961bf5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -56,7 +56,7 @@ class HloModuleGroupUtil { // Returns the root instructions of the computations. std::vector RootInstructions( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Visit state of each instruction during DFS traversal. enum VisitState { @@ -93,15 +93,14 @@ class HloModuleGroupUtil { HloInstruction* root); // Verifies that the computations are well-formed (e.g., no cycles). - Status VerifyComputations( - tensorflow::gtl::ArraySlice computations); + Status VerifyComputations(absl::Span computations); // Below Reachability utils resemble those in HloComputation, except that // they can handle instructions across multiple computations. // // Creates the reachability map for the instructions in the computations. StatusOr> ComputeReachability( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Updates the reachability of the given instruction, taking the global // predeccessorss and successors into account. diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 236f4500860a8673e61cbd2f861a8fc40c7861f7..39f38b417ab0e8b54864176d8d1e0ad1a422eca6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -15,21 +15,26 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" - +#include "tensorflow/core/lib/core/status_test_util.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { +namespace op = ::xla::testing::opcode_matchers; + class HloModuleTest : public HloTestBase { protected: HloModuleTest() {} @@ -44,7 +49,7 @@ class HloModuleTest : public HloTestBase { // Creates a computation which calls the given zero-parameter computations. std::unique_ptr CreateCallComputation( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { auto builder = HloComputation::Builder("Call"); for (auto computation : computations) { builder.AddInstruction( @@ -194,6 +199,153 @@ TEST_F(HloModuleTest, UniqueModuleId) { EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } +TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_FALSE(module_copy->has_schedule()); +} + +TEST_F(HloModuleTest, ProtoSerializationWithSchedule) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_TRUE(module_copy->has_schedule()); + TF_ASSERT_OK(module_copy->schedule().Verify()); + EXPECT_EQ(module_copy->schedule().sequences().size(), 1); + ASSERT_TRUE(module_copy->schedule().is_computation_scheduled( + module_copy->entry_computation())); + EXPECT_THAT( + module_copy->schedule() + .sequence(module_copy->entry_computation()) + .instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + +TEST_F(HloModuleTest, ProtoSerializationPreservesIds) { + // Verify that serializing then deserializing an HLO proto preserves the + // unique IDs of the instruction and module. + const string text = + R"(HloModule ReduceR3ToR2_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY ReduceR3ToR2.v3 { + input = f32[8,16,256]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + + // Perform various transformations on the graph: + // + // * clone the reduction function + // * replace use of reduction function with the clone. + // * add a random instruction to the entry computation. + // + // This will create instruction and computation IDs which are interesting: + // not consecutive and not densely packed. + HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + HloComputation* reduction = root->to_apply(); + HloComputation* reduction_clone = + module->AddEmbeddedComputation(reduction->Clone()); + root->set_to_apply(reduction_clone); + TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction)); + HloInstruction* negate = entry->AddInstruction( + HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root)); + entry->set_root_instruction(negate); + + // Schedule the transformed module, this verifies that the serialized schedule + // is robust against non-consecutive IDs as well (b/114712358). + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + HloMemoryScheduler scheduler(size_fn); + TF_ASSERT_OK(scheduler.Run(module.get()).status()); + ASSERT_TRUE(module->has_schedule()); + + // Serialize and deserialize and verify that the instruction and computations + // unique ids are the same. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + + // The module IDs should *not* be the same because module ids must be globally + // unique. + EXPECT_NE(module->unique_id(), module_copy->unique_id()); + + // Verify that the computations and instructions all have the same unique id. + auto computation_copy_it = module_copy->computations().begin(); + for (const HloComputation* computation_orig : module->computations()) { + const HloComputation* computation_copy = *computation_copy_it++; + EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id()) + << absl::StrFormat( + "ID of original computation %s != ID of deserialized " + "computation %s: %d != %d", + computation_orig->name(), computation_copy->name(), + computation_orig->unique_id(), computation_copy->unique_id()); + + auto instruction_copy_it = computation_copy->instructions().begin(); + for (const HloInstruction* instruction_orig : + computation_orig->instructions()) { + const HloInstruction* instruction_copy = *instruction_copy_it++; + EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id()) + << absl::StrFormat( + "ID of original instruction %s != ID of deserialized " + "instruction %s: %d != %d", + instruction_orig->name(), instruction_copy->name(), + instruction_orig->unique_id(), instruction_copy->unique_id()); + } + } + + // Verify that the next unique ID which the module would have handed out is + // greater than the unique id of any instruction. + int next_id = module_copy->NewUniqueInstructionId(); + for (const HloComputation* computation : module_copy->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + EXPECT_GT(next_id, instruction->unique_id()); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index d1eaf357855205f1e9867e86f3042b96b6beff97..2d4e38589fe4693e73c46d6c82e51cb0a8388f85 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -39,7 +39,7 @@ StatusOr StringToHloOpcode(const string& opcode_name) { }); auto it = opcode_map->find(opcode_name); if (it == opcode_map->end()) { - return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); + return InvalidArgument("Unknown opcode: %s", opcode_name); } return it->second; } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 88531b6f209380a3f1bffe4e78da960b6811d9fd..e6bfb8025d4bfeba1d334d1f946e33841a2da092 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -47,6 +47,7 @@ namespace xla { #define HLO_OPCODE_LIST(V) \ V(kAbs, "abs") \ V(kAdd, "add") \ + V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ V(kBatchNormInference, "batch-norm-inference") \ @@ -57,6 +58,7 @@ namespace xla { V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ + V(kCollectivePermute, "collective-permute") \ V(kClz, "count-leading-zeros") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ @@ -84,7 +86,6 @@ namespace xla { V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ - V(kHostCompute, "host-compute") \ V(kImag, "imag") \ V(kInfeed, "infeed") \ V(kIota, "iota") \ @@ -155,7 +156,7 @@ enum HloOpcodeProperty { // Returns a string representation of the opcode. string HloOpcodeString(HloOpcode opcode); -// Returns a string representation of the opcode. +// Retrieves the opcode enum by name if the opcode exists. StatusOr StringToHloOpcode(const string& opcode_name); inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 6c1e015f77a62c3e3ff7ffa5ce9dea735f46e10a..f1dc08bafa17a2dd68a7e922d4b84658bbf2589c 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,9 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -25,8 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -252,14 +253,36 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << a << " not defined before " << b; return false; } + + if (a.live_out_of_module()) { + VLOG(4) << a << " is live out of module and defined before " << b; + return false; + } + // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { + if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), + use.instruction)) { + continue; + } if (!UseIsBeforeValueDefinition(use, b, dataflow)) { VLOG(4) << "use of " << a << " (" << use << ") not before " << b << " is defined"; return false; } } + + if (a.instruction()->parent() == b.instruction()->parent()) { + for (const HloPosition& position : a.positions()) { + if (position.instruction == + a.instruction()->parent()->root_instruction()) { + VLOG(4) << a << " is live out of computation and defined before " << b + << " which is in same computation"; + return false; + } + } + } + return true; } @@ -270,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b, !LiveRangeStrictlyBefore(b, a, dataflow); } -HloOrderingProto HloOrdering::ToProto() const { - HloOrderingProto proto; - for (const auto& computation : module_->computations()) { - const std::vector* sequence = - SequentialOrder(*computation); - if (sequence != nullptr) { - HloOrderingProto::SequentialComputation* proto_computation = - proto.add_sequential_computations(); - proto_computation->set_computation_name(computation->name()); - for (const HloInstruction* instruction : *sequence) { - *proto_computation->add_instruction_names() = instruction->name(); - } - } - } - return proto; -} - PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) : HloOrdering(module) {} @@ -302,22 +308,20 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { std::vector pieces; pieces.push_back(name); for (auto* computation : module_->MakeNonfusionComputations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s:", - computation->name().c_str())); + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); const auto all = computation->MakeInstructionPostOrder(); for (auto instruction : all) { - pieces.push_back(tensorflow::strings::Printf( - " %s predecessors:", instruction->name().c_str())); + pieces.push_back( + absl::StrFormat(" %s predecessors:", instruction->name())); for (auto predecessor : all) { if (predecessors_.at(computation) ->IsReachable(predecessor, instruction)) { - pieces.push_back( - tensorflow::strings::Printf(" %s", predecessor->name().c_str())); + pieces.push_back(absl::StrFormat(" %s", predecessor->name())); } } } } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) @@ -334,15 +338,24 @@ string DependencyHloOrdering::ToString() const { return ToStringHelper("DependencyHloOrdering"); } -SequentialHloOrdering::SequentialHloOrdering( - const HloModule* module, const HloModuleSequence& module_sequence) - : HloOrdering(module), module_sequence_(module_sequence) { +SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule) + : HloOrdering(schedule.module()), schedule_(schedule) { + Initialize(); +} + +SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule) + : HloOrdering(schedule.module()), schedule_(std::move(schedule)) { + Initialize(); +} + +void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. - for (auto computation_order : module_sequence_) { - const std::vector& order = computation_order.second; + TF_DCHECK_OK(schedule_.Verify()); + for (const auto& computation_sequence : schedule_.sequences()) { + const std::vector& order = + computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { - DCHECK_EQ(0, order_position_.count(order[i])); - order_position_.emplace(order[i], i); + InsertOrDie(&order_position_, order[i], i); } } } @@ -360,50 +373,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const std::vector* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { - auto find_it = module_sequence_.find(&computation); - return find_it == module_sequence_.end() ? nullptr : &find_it->second; + return schedule_.is_computation_scheduled(&computation) + ? &schedule_.sequence(&computation).instructions() + : nullptr; } string SequentialHloOrdering::ToString() const { - std::vector pieces; - pieces.push_back("SequentialHloOrdering"); - for (auto* computation : module_->computations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s order:", - computation->name().c_str())); - // Gather all instructions in the module sequence for this computation and - // sort them by their position. - std::vector instructions; - for (auto& instruction_position : order_position_) { - const HloInstruction* instruction = instruction_position.first; - if (instruction->parent() == computation) { - instructions.push_back(instruction); - } - } - std::sort(instructions.begin(), instructions.end(), - [this](const HloInstruction* a, const HloInstruction* b) { - return order_position_.at(a) < order_position_.at(b); - }); - for (auto instruction : instructions) { - pieces.push_back( - tensorflow::strings::Printf(" %s", instruction->name().c_str())); - } - } - return tensorflow::str_util::Join(pieces, "\n"); -} - -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence) { - for (auto computation_pair : module_sequence) { - const HloComputation* computation = computation_pair.first; - const std::vector& computation_sequence = - computation_pair.second; - out << "Computation " << computation->name() << ":\n"; - for (auto* instruction : computation_sequence) { - out << " " << instruction->name() << "\n"; - } - } - return out; + return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 985f3fa64d8767b0c0063ee900f7d11c3b7f6d4a..b0361c3f02922bcaa14d52ad3b240701080f9b58 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -71,10 +72,6 @@ class HloOrdering { virtual string ToString() const = 0; - // Returns the serialized representation of this ordering. - // Only sequential computation orders are represented. - HloOrderingProto ToProto() const; - protected: // Returns true if instruction 'a' executes before instruction 'b'. // Precondition: 'a' and 'b' are in the same computation. @@ -183,17 +180,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: - // TODO(dimvar): HloModuleSequence is not a good name because it sounds like - // a sequence of modules, instead of a map of schedules for all computations - // in a module. We should change it at some point. - // - // A sequence of instructions for each computation in the module. - using HloModuleSequence = - tensorflow::gtl::FlatMap>; - - SequentialHloOrdering(const HloModule* module, - const HloModuleSequence& module_sequence); + SequentialHloOrdering(const HloSchedule& schedule); + SequentialHloOrdering(HloSchedule&& schedule); ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. @@ -203,10 +191,12 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: + void Initialize(); + bool ExecutesBeforeInSameComputation(const HloInstruction* a, const HloInstruction* b) const override; - const HloModuleSequence module_sequence_; + const HloSchedule schedule_; // The position of every instruction in the HLO module in its respective // computation sequence (a value of zero indicates the instruction is first in @@ -217,10 +207,6 @@ class SequentialHloOrdering : public HloOrdering { tensorflow::gtl::FlatMap order_position_; }; -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 126d3a2d9c70bff1d2a022e395652049768d6d21..00970bcda34209d33867099d0bcf3b2902d52ae8 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -23,11 +23,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -376,5 +377,104 @@ ENTRY root { dataflow->GetValueDefinedAt(add_3))); } +TEST_F(HloOrderingTest, + ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) { + // Tests that values live out of the module should interfere with values + // defined after the root instruction. That is: + // + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* entry = + module->AddEntryComputation(builder.Build(/*root_instruction=*/root)); + + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param, root, dead}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + +TEST_F(HloOrderingTest, + ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) { + // Tests that values live out of a computation should interfere with values + // defined after the root instruction of the computation. That is: + // + // subcomputation: + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // entry computation: + // %c = constant(42.0) + // ROOT %call = call({%c}), subcomputation + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto subbuilder = HloComputation::Builder(TestName() + ".sub"); + HloInstruction* param = subbuilder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = subbuilder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = subbuilder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* subcomputation = module->AddEmbeddedComputation( + subbuilder.Build(/*root_instruction=*/root)); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {c}, subcomputation)); + HloComputation* entry = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(subcomputation, {param, root, dead}); + schedule.set_sequence(entry, {c, call}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 93cc884e3a04a15eae927e1b8b9251c1d82290ad..37197b273ba09200dbf4dd04c6b7c4cacc068120 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -15,74 +15,100 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace { -using ::tensorflow::StringPiece; -using ::tensorflow::gtl::optional; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::Split; -using ::tensorflow::str_util::SplitAndParseAsInts; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrFormat; +using absl::StrJoin; const double kF16max = 65504; +// Creates and returns a schedule created using the order of the instructions in +// the HloComputation::instructions() vectors in the module. +HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { + HloSchedule schedule(module); + for (const HloComputation* computation : module->computations()) { + if (!computation->IsFusionComputation()) { + for (const HloInstruction* instruction : computation->instructions()) { + schedule.GetOrCreateSequence(computation).push_back(instruction); + } + } + } + return schedule; +} + // Parser for the HloModule::ToString() format text. class HloParser { public: using LocTy = HloLexer::LocTy; - explicit HloParser(StringPiece str, const HloModuleConfig& config) - : lexer_(str), config_(config) {} - - // Runs the parser. Returns false if an error occurred. - bool Run(); + explicit HloParser(absl::string_view str) : lexer_(str) {} - // Returns the parsed HloModule. - std::unique_ptr ConsumeHloModule() { return std::move(module_); } + // Runs the parser and constructs the resulting HLO in the given (empty) + // HloModule. Returns false if an error occurred. + bool Run(HloModule* module); // Returns the error information. - string GetError() const { return Join(error_, "\n"); } + string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. StatusOr ParseShardingOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); + StatusOr ParsePaddingConfigOnly(); + + // Stand-alone parsing utility for a single instruction worth of text. + Status ParseSingleInstruction(HloComputation::Builder* builder, + string* root_name); private: + // Locates an instruction with the given name in the instruction_pool_ or + // returns nullptr. + // + // If the missing_instruction_hook_ is registered and a "shape" is provided, + // the hook will be called and may satisfy the request for the given + // instruction. This is useful when we reify parameters as they're resolved; + // i.e. for ParseSingleInstruction. + std::pair* FindInstruction( + const string& name, const optional& shape = nullopt); + // ParseXXX returns false if an error occurred. - bool ParseHloModule(); - bool ParseComputations(); + bool ParseHloModule(HloModule* module); + bool ParseComputations(HloModule* module); bool ParseComputation(HloComputation** entry_computation); bool ParseInstructionList(HloComputation::Builder* builder, string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); bool ParseControlPredecessors(HloInstruction* instruction); - bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseNonTupleLiteral(std::unique_ptr* literal, - const Shape& shape); - bool ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape); + bool ParseLiteral(Literal* literal, const Shape& shape); + bool ParseTupleLiteral(Literal* literal, const Shape& shape); + bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); + bool ParseDenseLiteral(Literal* literal, const Shape& shape); + bool ParseSparseLiteral(Literal* literal, const Shape& shape); template - bool ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape); + bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. @@ -125,6 +151,7 @@ class HloParser { kFloat, kString, kBracedInt64List, + kBracedInt64ListList, kHloComputation, kFftType, kWindow, @@ -137,6 +164,7 @@ class HloParser { kFusionKind, kDistribution, kDomain, + kPrecisionList, }; struct AttrConfig { @@ -202,9 +230,14 @@ class HloParser { bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); + bool ParsePrecisionList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); + // 'parse_and_add_item' is an lambda to parse an element in the list and add + // the parsed element to the result. It's supposed to capture the result. + bool ParseList(const TokKind start, const TokKind end, const TokKind delim, + const std::function& parse_and_add_item); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -216,6 +249,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); + bool ParsePrecision(PrecisionConfig::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -228,8 +262,8 @@ class HloParser { bool CanBeParamListToShape(); // Logs the current parsing line and the given message. Always returns false. - bool TokenError(StringPiece msg); - bool Error(LocTy loc, StringPiece msg); + bool TokenError(absl::string_view msg); + bool Error(LocTy loc, absl::string_view msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. @@ -256,38 +290,78 @@ class HloParser { computation_pool_; HloLexer lexer_; - std::unique_ptr module_; std::vector> computations_; - const HloModuleConfig config_; std::vector error_; + + // Function that gets invoked when we try to resolve an instruction + // instruction_pool_ but fail to do so. + std::function*(string, + const optional&)> + missing_instruction_hook_; }; -bool HloParser::Error(LocTy loc, StringPiece msg) { +bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { + for (const auto& split : absl::StrSplit(s, delim)) { + int64 val; + if (!absl::SimpleAtoi(split, &val)) { + return false; + } + out->push_back(val); + } + return true; +} + +// Creates replica groups from the provided nested array. groups[i] represents +// the replica ids for group 'i'. +std::vector CreateReplicaGroups( + absl::Span> groups) { + std::vector replica_groups; + absl::c_transform(groups, std::back_inserter(replica_groups), + [](const std::vector& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + return replica_groups; +} + +bool HloParser::Error(LocTy loc, absl::string_view msg) { auto line_col = lexer_.GetLineAndColumn(loc); const unsigned line = line_col.first; const unsigned col = line_col.second; std::vector error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); - error_lines.push_back(std::string(lexer_.GetLine(loc))); + error_lines.emplace_back(lexer_.GetLine(loc)); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(Join(error_lines, "\n")); + error_.push_back(StrJoin(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } -bool HloParser::TokenError(StringPiece msg) { +bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } -bool HloParser::Run() { +bool HloParser::Run(HloModule* module) { lexer_.Lex(); - return ParseHloModule(); + return ParseHloModule(module); +} + +std::pair* HloParser::FindInstruction( + const string& name, const optional& shape) { + std::pair* instr = + tensorflow::gtl::FindOrNull(instruction_pool_, name); + // Potentially call the missing instruction hook. + if (instr == nullptr && missing_instruction_hook_ != nullptr) { + return missing_instruction_hook_(name, shape); + } + return instr; } // ::= 'HloModule' name computations -bool HloParser::ParseHloModule() { +bool HloParser::ParseHloModule(HloModule* module) { if (lexer_.GetKind() != TokKind::kw_HloModule) { return TokenError("expects HloModule"); } @@ -299,13 +373,27 @@ bool HloParser::ParseHloModule() { return false; } - module_ = MakeUnique(name, config_); + absl::optional is_scheduled; + std::unordered_map attrs; + attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; + if (!ParseAttributes(attrs)) { + return false; + } + + module->set_name(name); + if (!ParseComputations(module)) { + return false; + } + + if (is_scheduled.has_value() && *is_scheduled) { + TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module))); + } - return ParseComputations(); + return true; } // computations ::= (computation)+ -bool HloParser::ParseComputations() { +bool HloParser::ParseComputations(HloModule* module) { HloComputation* entry_computation = nullptr; do { if (!ParseComputation(&entry_computation)) { @@ -321,21 +409,20 @@ bool HloParser::ParseComputations() { if ((entry_computation != nullptr && computations_[i].get() != entry_computation) || (entry_computation == nullptr && i != computations_.size() - 1)) { - module_->AddEmbeddedComputation(std::move(computations_[i])); + module->AddEmbeddedComputation(std::move(computations_[i])); continue; } - auto computation = - module_->AddEntryComputation(std::move(computations_[i])); + auto computation = module->AddEntryComputation(std::move(computations_[i])); // The parameters and result layouts were set to default layout. Here we // set the layouts to what the hlo text says. for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module->mutable_entry_computation_layout() ->mutable_parameter_layout(p) ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module->mutable_entry_computation_layout() ->mutable_result_layout() ->CopyLayoutFromShape(result_shape)); } @@ -352,7 +439,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = MakeUnique(name); + auto builder = absl::make_unique(name); LocTy shape_loc = nullptr; Shape shape; @@ -365,8 +452,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - std::pair* root_node = - tensorflow::gtl::FindOrNull(instruction_pool_, root_name); + std::pair* root_node = FindInstruction(root_name); // This means some instruction was marked as ROOT but we didn't find it in the // pool, which should not happen. if (!root_name.empty() && root_node == nullptr) { @@ -480,7 +566,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConstant: { - std::unique_ptr literal; + Literal literal; if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || @@ -493,11 +579,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kIota: { + optional iota_dimension; + attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, + &iota_dimension}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateIota(shape)); + instruction = builder->AddInstruction( + HloInstruction::CreateIota(shape, *iota_dimension)); break; } // Unary ops. @@ -592,31 +682,66 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional>> tmp_groups; optional to_apply; optional> replica_group_ids; optional barrier; optional all_reduce_id; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; - attrs["replica_group_ids"] = { - /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - if (replica_group_ids) { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, *replica_group_ids, - barrier ? *barrier : "", all_reduce_id)); - } else { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, {}, barrier ? *barrier : "", - all_reduce_id)); + std::vector replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); } + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, replica_groups, + barrier ? *barrier : "", all_reduce_id)); + break; + } + case HloOpcode::kAllToAll: { + optional>> tmp_groups; + optional barrier; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + std::vector replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); + } + instruction = builder->AddInstruction( + HloInstruction::CreateAllToAll(shape, operands, replica_groups)); + break; + } + case HloOpcode::kCollectivePermute: { + optional>> source_targets; + attrs["source_target_pairs"] = { + /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + std::vector> pairs(source_targets->size()); + for (int i = 0; i < pairs.size(); i++) { + if ((*source_targets)[i].size() != 2) { + return TokenError( + "expects 'source_target_pairs=' to be a list of pairs"); + } + pairs[i].first = (*source_targets)[i][0]; + pairs[i].second = (*source_targets)[i][1]; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); break; } case HloOpcode::kReshape: { @@ -798,9 +923,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kConvolution: { optional window; optional dnums; + optional feature_group_count; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/true, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; + optional> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; @@ -808,8 +939,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!window) { window.emplace(); } + if (!feature_group_count) { + feature_group_count = 1; + } + PrecisionConfig precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfig::DEFAULT); + } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( - shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], + feature_group_count.value(), *window, *dnums, precision_config)); break; } case HloOpcode::kFft: { @@ -882,11 +1025,11 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } instruction = builder->AddInstruction(HloInstruction::CreateReduce( shape, /*operands=*/ - tensorflow::gtl::ArraySlice(operands, 0, - operands.size() / 2), + absl::Span(operands).subspan( + 0, operands.size() / 2), /*init_values=*/ - tensorflow::gtl::ArraySlice( - operands, operands.size() / 2, operands.size()), + absl::Span(operands).subspan( + operands.size() / 2, operands.size()), *dimensions_to_reduce, *reduce_computation)); break; } @@ -1046,7 +1189,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kInfeed: { optional config; attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } // We need to know the infeed data shape to construct the infeed @@ -1058,41 +1202,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return Error(lexer_.GetLoc(), "infeed must have a non-empty tuple shape"); } - - if (operands.empty()) { - // TODO(b/80000000): Remove this when all uses of infeed are - // converted to take tokens. - instruction = builder->AddInstruction(HloInstruction::CreateInfeed( - ShapeUtil::GetTupleElementShape(shape, 0), config ? *config : "")); - } else if (operands.size() == 1) { - instruction = builder->AddInstruction(HloInstruction::CreateInfeed( - ShapeUtil::GetTupleElementShape(shape, 0), operands[0], - config ? *config : "")); - } else { - return Error(lexer_.GetLoc(), - "infeed must have exactly zero or one operands"); - } + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), operands[0], + config ? *config : "")); break; } case HloOpcode::kOutfeed: { optional config; attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { return false; } - if (operands.size() == 1) { - // TODO(b/80000000): Remove this when all uses of outfeed are - // converted to take tokens. - instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( - operands[0]->shape(), operands[0], config ? *config : "")); - } else if (operands.size() == 2) { - instruction = builder->AddInstruction( - HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], - operands[1], config ? *config : "")); - } else { - return Error(lexer_.GetLoc(), - "outfeed must have exactly one or two operands"); - } + instruction = builder->AddInstruction( + HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], + operands[1], config ? *config : "")); break; } case HloOpcode::kRng: { @@ -1144,11 +1268,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional custom_call_target; optional window; optional dnums; + optional feature_group_count; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1160,20 +1287,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (dnums.has_value()) { instruction->set_convolution_dimension_numbers(*dnums); } - break; - } - case HloOpcode::kHostCompute: { - optional channel_name; - optional cost_estimate_ns; - attrs["channel_name"] = {/*required=*/true, AttrTy::kString, - &channel_name}; - attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, - &cost_estimate_ns}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { - return false; + if (feature_group_count.has_value()) { + instruction->set_feature_group_count(*feature_group_count); } - instruction = builder->AddInstruction(HloInstruction::CreateHostCompute( - shape, operands, *channel_name, *cost_estimate_ns)); break; } case HloOpcode::kDot: { @@ -1189,6 +1305,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; + optional> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1213,27 +1332,35 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, rhs_batch_dims->end()}; } - instruction = builder->AddInstruction( - HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); + PrecisionConfig precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfig::DEFAULT); + } + + instruction = builder->AddInstruction(HloInstruction::CreateDot( + shape, operands[0], operands[1], dnum, precision_config)); break; } case HloOpcode::kGather: { - optional> output_window_dims; - attrs["output_window_dims"] = { - /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; - optional> elided_window_dims; - attrs["elided_window_dims"] = { - /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; - optional> gather_dims_to_operand_dims; - attrs["gather_dims_to_operand_dims"] = {/*required=*/true, - AttrTy::kBracedInt64List, - &gather_dims_to_operand_dims}; + optional> offset_dims; + attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, + &offset_dims}; + optional> collapsed_slice_dims; + attrs["collapsed_slice_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims}; + optional> start_index_map; + attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List, + &start_index_map}; optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> window_bounds; - attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, - &window_bounds}; + optional> slice_sizes; + attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List, + &slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1242,14 +1369,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, GatherDimensionNumbers dim_numbers = HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/*output_window_dims, - /*elided_window_dims=*/*elided_window_dims, - /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*offset_dims=*/*offset_dims, + /*collapsed_slice_dims=*/*collapsed_slice_dims, + /*start_index_map=*/*start_index_map, /*index_vector_dim=*/*index_vector_dim); instruction = builder->AddInstruction(HloInstruction::CreateGather( - shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], - dim_numbers, *window_bounds)); + shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + dim_numbers, *slice_sizes)); break; } case HloOpcode::kScatter: { @@ -1383,7 +1510,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, bool replicated = false; std::vector devices; std::vector tile_assignment_dimensions; - Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -1434,7 +1560,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, break; } case TokKind::kShape: - tile_shape = lexer_.GetShapeVal(); + // TODO(b/112302613): Left here for backward compatibility to ignore the + // removed tile shape data. lexer_.Lex(); break; case TokKind::kRbrace: @@ -1449,19 +1576,12 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error(loc, "replicated shardings should not have any devices assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, - "replicated shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { return Error(loc, "maximal shardings should have exactly one device assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "maximal shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); sharding->add_tile_assignment_devices(devices[0]); } else { @@ -1469,9 +1589,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error( loc, "non-maximal shardings must have more than one device assigned"); } - if (ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "non-maximal shardings should have a tile shape set"); - } if (tile_assignment_dimensions.empty()) { return Error( loc, @@ -1479,7 +1596,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, "dimensions"); } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding->mutable_tile_shape() = tile_shape; for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } @@ -1506,14 +1622,14 @@ bool HloParser::ParseDomain(DomainData* domain) { return false; } if (*kind == ShardingMetadata::KindName()) { - auto entry_sharding_ptr = MakeUnique( + auto entry_sharding_ptr = absl::make_unique( HloSharding::FromProto(*entry_sharding).ValueOrDie()); - auto exit_sharding_ptr = MakeUnique( + auto exit_sharding_ptr = absl::make_unique( HloSharding::FromProto(*exit_sharding).ValueOrDie()); domain->entry_metadata = - MakeUnique(std::move(entry_sharding_ptr)); + absl::make_unique(std::move(entry_sharding_ptr)); domain->exit_metadata = - MakeUnique(std::move(exit_sharding_ptr)); + absl::make_unique(std::move(exit_sharding_ptr)); } else { return TokenError(StrCat("unsupported domain kind: ", *kind)); } @@ -1533,11 +1649,9 @@ bool HloParser::ParseInstructionNames( if (!ParseName(&name)) { return Error(loc, "expects a instruction name"); } - std::pair* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair* instr = FindInstruction(name); if (!instr) { - return TokenError( - Printf("instruction '%s' is not defined", name.c_str())); + return TokenError(StrFormat("instruction '%s' is not defined", name)); } instructions->push_back(instr->first); } while (EatIfPresent(TokKind::kComma)); @@ -1685,8 +1799,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { // literal // ::= tuple // ::= non_tuple -bool HloParser::ParseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) : ParseNonTupleLiteral(literal, shape); } @@ -1696,8 +1809,7 @@ bool HloParser::ParseLiteral(std::unique_ptr* literal, // literal_list // ::= /*empty*/ // ::= literal (',' literal)* -bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return TokenError(StrCat("expects tuple constant in shape ", ShapeUtil::HumanString(shape))); @@ -1705,8 +1817,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } - std::vector> elements( - ShapeUtil::TupleElementCount(shape)); + std::vector elements(ShapeUtil::TupleElementCount(shape)); if (lexer_.GetKind() == TokKind::kRparen) { // empty @@ -1732,8 +1843,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, // ::= rank01 // ::= rank2345 // rank2345 ::= shape sparse_or_nested_array -bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { if (LayoutUtil::IsSparseArray(shape)) { return ParseSparseLiteral(literal, shape); } @@ -1742,8 +1852,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, return ParseDenseLiteral(literal, shape); } -bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1766,10 +1875,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, std::vector elems_seen_until_dim( elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - Join(elems_seen_until_dim, ",", - [](string* out, const tensorflow::int64& num_elems) { - StrAppend(out, num_elems - 1); - }), + StrJoin(elems_seen_until_dim, ",", + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1779,17 +1888,17 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, case TokKind::kLbrace: { nest_level++; if (nest_level > rank) { - return TokenError(Printf( - "expects nested array in rank %lld, but sees larger", rank)); + return TokenError(absl::StrFormat( + "expects nested array in rank %d, but sees larger", rank)); } if (nest_level > 1) { elems_seen_per_dim[nest_level - 2]++; if (elems_seen_per_dim[nest_level - 2] > shape.dimensions(nest_level - 2)) { - return TokenError(Printf( - "expects %lld elements in the %sth element, but sees more", + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees more", shape.dimensions(nest_level - 2), - get_index_str(nest_level - 2).c_str())); + get_index_str(nest_level - 2))); } } lexer_.Lex(); @@ -1798,9 +1907,9 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, case TokKind::kRbrace: { nest_level--; if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) { - return TokenError(Printf( - "expects %lld elements in the %sth element, but sees %lld", - shape.dimensions(nest_level), get_index_str(nest_level).c_str(), + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees %d", + shape.dimensions(nest_level), get_index_str(nest_level), elems_seen_per_dim[nest_level])); } elems_seen_per_dim[nest_level] = 0; @@ -1808,7 +1917,6 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, break; } case TokKind::kComma: - case TokKind::kComment: // Skip. lexer_.Lex(); break; @@ -1822,15 +1930,15 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, if (rank > 0) { if (nest_level != rank) { return TokenError( - Printf("expects nested array in rank %lld, but sees %lld", rank, - nest_level)); + absl::StrFormat("expects nested array in rank %d, but sees %d", + rank, nest_level)); } elems_seen_per_dim[rank - 1]++; if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { - return TokenError( - Printf("expects %lld elements on the minor-most dimension, but " - "sees more", - shape.dimensions(rank - 1))); + return TokenError(absl::StrFormat( + "expects %d elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); } } if (lexer_.GetKind() == TokKind::kw_true || @@ -1838,7 +1946,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // TODO(congliu): bool type literals with rank >= 1 are actually // printed in a compact form instead of "true" or "false". Fix that. if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, - linear_index++, literal->get())) { + linear_index++, literal)) { return false; } lexer_.Lex(); @@ -1849,7 +1957,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { @@ -1860,7 +1968,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else { @@ -1872,12 +1980,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, } // end of switch } while (nest_level > 0); - *literal = (*literal)->Relayout(shape.layout()); + *literal = literal->Relayout(shape.layout()); return true; } -bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return false; } @@ -1917,13 +2024,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, } template -bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { std::vector index; tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = MakeUnique(shape); + *literal = Literal(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -1957,7 +2063,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", Join(index, ", "), "]")); + ": [", StrJoin(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -1997,7 +2103,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return false; } - if ((*literal)->sparse_element_count() + 1 == + if (literal->sparse_element_count() + 1 == LayoutUtil::MaxSparseElements(shape.layout())) { return Error( lexer_.GetLoc(), @@ -2005,10 +2111,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, ShapeUtil::HumanStringWithLayout(shape))); } - (*literal)->AppendSparseElement(index, value); + literal->AppendSparseElement(index, value); } - (*literal)->SortSparseElements(); + literal->SortSparseElements(); return true; } @@ -2018,6 +2124,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, // ::= operand (, operand)* // operand ::= (shape)? name bool HloParser::ParseOperands(std::vector* operands) { + CHECK(operands != nullptr); if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of operands")) { return false; @@ -2028,9 +2135,10 @@ bool HloParser::ParseOperands(std::vector* operands) { do { LocTy loc = lexer_.GetLoc(); string name; + optional shape; if (CanBeShape()) { - Shape shape; - if (!ParseShape(&shape)) { + shape.emplace(); + if (!ParseShape(&shape.value())) { return false; } } @@ -2038,8 +2146,8 @@ bool HloParser::ParseOperands(std::vector* operands) { return false; } std::pair* instruction = - tensorflow::gtl::FindOrNull(instruction_pool_, name); - if (!instruction) { + FindInstruction(name, shape); + if (instruction == nullptr) { return Error(loc, StrCat("instruction does not exist: ", name)); } operands->push_back(instruction->first); @@ -2050,6 +2158,7 @@ bool HloParser::ParseOperands(std::vector* operands) { bool HloParser::ParseOperands(std::vector* operands, const int expected_size) { + CHECK(operands != nullptr); LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { return false; @@ -2083,8 +2192,8 @@ bool HloParser::ParseSubAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("sub-attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("sub-attribute %s is expected but not seen", + attr_it.first)); } } return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); @@ -2104,8 +2213,8 @@ bool HloParser::ParseAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("attribute %s is expected but not seen", + attr_it.first)); } } return true; @@ -2121,7 +2230,7 @@ bool HloParser::ParseAttributeHelper( } VLOG(1) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { - return Error(loc, Printf("attribute %s already exists", name.c_str())); + return Error(loc, StrFormat("attribute %s already exists", name)); } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { @@ -2131,13 +2240,13 @@ bool HloParser::ParseAttributeHelper( } else { allowed_attrs = StrCat( "Allowed attributes: ", - Join(attrs, ", ", - [&](string* out, const std::pair& kv) { - StrAppend(out, kv.first); - })); + StrJoin(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); } - return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), - allowed_attrs.c_str())); + return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name, + allowed_attrs)); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; @@ -2255,6 +2364,26 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kBracedInt64ListList: { + std::vector> result; + auto parse_and_add_item = [&]() { + std::vector item; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, + TokKind::kComma, &item)) { + return false; + } + result.push_back(item); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + static_cast>>*>( + attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kSliceRanges: { SliceRanges result; if (!ParseSliceRanges(&result)) { @@ -2299,10 +2428,20 @@ bool HloParser::ParseAttributeHelper( case AttrTy::kDomain: { return ParseDomain(static_cast(attr_out_ptr)); } + case AttrTy::kPrecisionList: { + std::vector result; + if (!ParsePrecisionList(&result)) { + return false; + } + static_cast>*>( + attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { - return Error(loc, Printf("error parsing attribute %s", name.c_str())); + return Error(loc, StrFormat("error parsing attribute %s", name)); } return true; } @@ -2417,20 +2556,24 @@ bool HloParser::ParseConvolutionDimensionNumbers( } string str = lexer_.GetStrVal(); - // The str is expected to have 3 items, lhs, rhs, out, and it must looks like + // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - // So we replace the "->" with "_" and then split on "_". - str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", - /*newsub=*/"_", - /*replace_all=*/false); - std::vector lhs_rhs_out = Split(str, "_"); - if (lhs_rhs_out.size() != 3) { + std::vector split1 = absl::StrSplit(str, "_"); + if (split1.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; } + std::vector split2 = absl::StrSplit(split1[1], "->"); + if (split2.size() != 2) { + LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " + << str; + } + absl::string_view lhs = split1[0]; + absl::string_view rhs = split2[0]; + absl::string_view out = split2[1]; - const tensorflow::int64 rank = lhs_rhs_out[0].length(); - if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { + const tensorflow::int64 rank = lhs.length(); + if (rank != rhs.length() || rank != out.length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); } @@ -2445,8 +2588,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // lhs { - const string& lhs = lhs_rhs_out[0]; - if (!is_unique(lhs)) { + if (!is_unique(string(lhs))) { return TokenError( StrCat("expects unique lhs dimension numbers, but sees ", lhs)); } @@ -2463,14 +2605,13 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_input_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); + StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1)); } } } // rhs { - const string& rhs = lhs_rhs_out[1]; - if (!is_unique(rhs)) { + if (!is_unique(string(rhs))) { return TokenError( StrCat("expects unique rhs dimension numbers, but sees ", rhs)); } @@ -2487,14 +2628,13 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_kernel_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); + StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1)); } } } // output { - const string& out = lhs_rhs_out[2]; - if (!is_unique(out)) { + if (!is_unique(string(out))) { return TokenError( StrCat("expects unique output dimension numbers, but sees ", out)); } @@ -2510,8 +2650,8 @@ bool HloParser::ParseConvolutionDimensionNumbers( } else if (c < '0' + rank && c >= '0') { dnums->set_output_spatial_dimensions(c - '0', i); } else { - return TokenError( - Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); + return TokenError(StrFormat( + "expects [0-%dbf] in output dimension numbers", rank - 1)); } } } @@ -2557,9 +2697,10 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { } const auto& range = ranges.back(); if (range.size() != 2 && range.size() != 3) { - return Error(loc, Printf("expects [start:limit:step] or [start:limit], " - "but sees %ld elements.", - range.size())); + return Error(loc, + StrFormat("expects [start:limit:step] or [start:limit], " + "but sees %d elements.", + range.size())); } } while (EatIfPresent(TokKind::kComma)); @@ -2571,6 +2712,24 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); } +// precisionlist ::= start precision_elements end +// precision_elements +// ::= /*empty*/ +// ::= precision_val (delim precision_val)* +bool HloParser::ParsePrecisionList( + std::vector* result) { + auto parse_and_add_item = [&]() { + PrecisionConfig::Precision item; + if (!ParsePrecision(&item)) { + return false; + } + result->push_back(item); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2597,6 +2756,26 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +bool HloParser::ParseList(const TokKind start, const TokKind end, + const TokKind delim, + const std::function& parse_and_add_item) { + if (!ParseToken(start, StrCat("expects a list starting with ", + TokKindToString(start)))) { + return false; + } + if (lexer_.GetKind() == end) { + // empty + } else { + do { + if (!parse_and_add_item()) { + return false; + } + } while (EatIfPresent(delim)); + } + return ParseToken( + end, StrCat("expects a list to end with ", TokKindToString(end))); +} + // param_list_to_shape ::= param_list '->' shape bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { @@ -2707,14 +2886,13 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { - return Error(loc, - Printf("sub-attribute '%s=' already exists", name.c_str())); + return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); } // 1D if (lexer_.GetKind() == TokKind::kInt) { tensorflow::int64 number; if (!ParseInt64(&number)) { - return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); + return Error(loc, StrFormat("expects sub-attribute '%s=i'", name)); } result->push_back(number); return true; @@ -2722,9 +2900,8 @@ bool HloParser::ParseDxD(const string& name, // 2D or higher. if (lexer_.GetKind() == TokKind::kDxD) { string str = lexer_.GetStrVal(); - if (!SplitAndParseAsInts(str, 'x', result)) { - return Error(loc, - Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + if (!SplitToInt64s(str, 'x', result)) { + return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name)); } lexer_.Lex(); return true; @@ -2742,10 +2919,9 @@ bool HloParser::ParseWindowPad( return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); } string str = lexer_.GetStrVal(); - std::vector padding_str = Split(str, 'x'); - for (int i = 0; i < padding_str.size(); i++) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector low_high; - if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || + if (!SplitToInt64s(padding_dim_str, '_', &low_high) || low_high.size() != 2) { return Error(loc, "expects padding_low and padding_high separated by '_'"); @@ -2766,10 +2942,9 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { } LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); - std::vector padding_str = Split(str, 'x'); - for (const auto& padding_dim_str : padding_str) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector padding_dim; - if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || + if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, "expects padding config pattern like 'low_high_interior' or " @@ -2821,9 +2996,8 @@ bool HloParser::ParseOpcode(HloOpcode* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToHloOpcode(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects opcode but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2837,7 +3011,7 @@ bool HloParser::ParseFftType(FftType* result) { } string val = lexer_.GetStrVal(); if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { - return TokenError(Printf("expects fft type but sees: %s", val.c_str())); + return TokenError(StrFormat("expects fft type but sees: %s", val)); } lexer_.Lex(); return true; @@ -2851,9 +3025,9 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToFusionKind(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2869,8 +3043,25 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { auto status_or_result = StringToRandomDistribution(val); if (!status_or_result.ok()) { return TokenError( - Printf("expects random distribution but sees: %s, error: %s", - val.c_str(), status_or_result.status().error_message().c_str())); + StrFormat("expects random distribution but sees: %s, error: %s", val, + status_or_result.status().error_message())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { + VLOG(1) << "ParsePrecision"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects random distribution"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToPrecision(val); + if (!status_or_result.ok()) { + return TokenError(StrFormat("expects precision but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2964,7 +3155,7 @@ StatusOr HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; if (!ParseSharding(&op_sharding)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after sharding"); @@ -2976,7 +3167,7 @@ StatusOr HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after window"); @@ -2989,7 +3180,7 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { lexer_.Lex(); ConvolutionDimensionNumbers dnums; if (!ParseConvolutionDimensionNumbers(&dnums)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument( @@ -2998,40 +3189,113 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { return dnums; } +StatusOr HloParser::ParsePaddingConfigOnly() { + lexer_.Lex(); + PaddingConfig padding_config; + if (!ParsePaddingConfig(&padding_config)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after PaddingConfig"); + } + return padding_config; +} + +Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, + string* root_name) { + TF_RET_CHECK(missing_instruction_hook_ == nullptr); + + // The missing instruction hook we register creates the shaped instruction on + // the fly as a parameter and returns it. + int64 parameter_count = 0; + missing_instruction_hook_ = + [this, builder, ¶meter_count]( + string name, + const optional& shape) -> std::pair* { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + StrCat("Operand ", name, + " had no shape in HLO text; cannot create parameter for " + "single-instruction module.")); + return nullptr; + } + HloInstruction* parameter = builder->AddInstruction( + HloInstruction::CreateParameter(parameter_count++, *shape, name)); + instruction_pool_[name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(instruction_pool_, name); + }; + + // Prime the lexer. + lexer_.Lex(); + + // Parse the instruction with the registered hook. + if (!ParseInstruction(builder, root_name)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + return Status::OK(); +} + } // namespace StatusOr> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config) { - HloParser parser(str, config); - if (!parser.Run()) { - return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); + absl::string_view str, const HloModuleConfig& config) { + auto module = absl::make_unique(/*name=*/"", config); + HloParser parser(str); + if (!parser.Run(module.get())) { + return InvalidArgument("Syntax error:\n%s", parser.GetError()); } - return parser.ConsumeHloModule(); + return std::move(module); } -StatusOr> ParseHloString( - tensorflow::StringPiece str) { - HloModuleConfig config; - return ParseHloString(str, config); +StatusOr> ParseHloString(absl::string_view str) { + auto module = absl::make_unique(/*name=*/"", HloModuleConfig()); + HloParser parser(str); + if (!parser.Run(module.get())) { + return InvalidArgument("Syntax error:\n%s", parser.GetError()); + } + return std::move(module); +} + +Status ParseHloString(absl::string_view str, HloModule* module) { + TF_RET_CHECK(module->computation_count() == 0); + HloParser parser(str); + if (!parser.Run(module)) { + return InvalidArgument("Syntax error:\n%s", parser.GetError()); + } + return Status::OK(); } -StatusOr ParseSharding(tensorflow::StringPiece str) { - HloModuleConfig config; - HloParser parser(str, config); +StatusOr> ParseHloOpToModule( + absl::string_view str, absl::string_view name) { + HloParser parser(str); + auto builder = absl::make_unique(string(name)); + string root_name; + TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); + std::unique_ptr computation = builder->Build(); + auto module = absl::make_unique(string(name), HloModuleConfig()); + module->AddEntryComputation(std::move(computation)); + return std::move(module); +} + +StatusOr ParseSharding(absl::string_view str) { + HloParser parser(str); return parser.ParseShardingOnly(); } -StatusOr ParseWindow(tensorflow::StringPiece str) { - HloModuleConfig config; - HloParser parser(str, config); +StatusOr ParseWindow(absl::string_view str) { + HloParser parser(str); return parser.ParseWindowOnly(); } StatusOr ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str) { - HloModuleConfig config; - HloParser parser(str, config); + absl::string_view str) { + HloParser parser(str); return parser.ParseConvolutionDimensionNumbersOnly(); } +StatusOr ParsePaddingConfig(absl::string_view str) { + HloParser parser(str); + return parser.ParsePaddingConfigOnly(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 3f3a51215e34bbdd667f1cb20d0ae968e0ce5efd..369603551463fd4b4911b393f3c6c2b36f0e4bbb 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_lexer.h" @@ -29,30 +30,42 @@ namespace xla { // For details about the syntax accepted by this parser, see // g3doc/hlo_parser.md. -// The api of the hlo parser. Given a string in the HloModule::ToString() -// format, parses the string and creates a HloModule with the given config. +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with the given config. StatusOr> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config); + absl::string_view str, const HloModuleConfig& config); -// The api of the hlo parser. Given a string in the HloModule::ToString() -// format, parses the string and creates a HloModule with default config. -StatusOr> ParseHloString( - tensorflow::StringPiece str); +// Given a string in the HloModule::ToString() format, parses the string and +// builds the HloModule in place at the given module pointer. 'module' must +// point to an empty module (no computations). +Status ParseHloString(absl::string_view str, HloModule* module); + +// Parses the text for a single HLO operation into an HLO module with a function +// that runs that operation (with the same parameters) as its entry computation. +StatusOr> ParseHloOpToModule( + absl::string_view str, absl::string_view name = "single_op"); + +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with default config. +StatusOr> ParseHloString(absl::string_view str); // Parses the result of HloSharding::ToString(), e.g. "{replicated}". -StatusOr ParseSharding(tensorflow::StringPiece str); +StatusOr ParseSharding(absl::string_view str); // Parses the result of window_util::ToString(const Window&). -StatusOr ParseWindow(tensorflow::StringPiece str); +StatusOr ParseWindow(absl::string_view str); // Parses the result of ConvolutionDimensionNumbersToString(), e.g. // "b0f_0io->b0f". StatusOr ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str); + absl::string_view str); // ParseHloString sharding from str. str is supposed to contain the body of the // sharding, i.e. just the rhs of the "sharding={...}" attribute string. -StatusOr ParseSharding(tensorflow::StringPiece str); +StatusOr ParseSharding(absl::string_view str); + +// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". +StatusOr ParsePaddingConfig(absl::string_view str); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 7344679bb619b841483dde461b634a38b1490d44..cca50fab5444d5e23c02952d56566b643a2192a4 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -16,17 +16,21 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { - namespace { -using ::tensorflow::StringPiece; +namespace op = ::xla::testing::opcode_matchers; +using absl::string_view; struct TestData { string test_name; @@ -380,7 +384,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default} } )" @@ -752,10 +756,10 @@ ENTRY %sparse_f32_r1 () -> f32[9] { "gather", R"(HloModule StringifyGather -ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) - %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) - ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} + %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26} } )" @@ -1030,8 +1034,8 @@ R"(HloModule gather ENTRY Gather { input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) - gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) - ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} + start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26} } )" @@ -1049,7 +1053,7 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add + ROOT crs = f32[8]{0} cross-replica-sum(input), replica_groups={}, to_apply=add } )" @@ -1067,7 +1071,43 @@ add { ENTRY CrossReplicaSumWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add +} + +)" +}, +// all-to-all +{ +"AllToAll", +R"(HloModule AllToAll + +ENTRY AllToAll { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={} +} + +)" +}, +// all-to-all with subgroups +{ +"AllToAllWithSubgroups", +R"(HloModule AllToAllWithSubgroups + +ENTRY AllToAllWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}} +} + +)" +}, +// collective-permute +{ +"CollectivePermute", +R"(HloModule CollectivePermute + +ENTRY CollectivePermute { + input = f32[128,32]{0,1} parameter(0) + ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } )" @@ -1078,31 +1118,44 @@ ENTRY CrossReplicaSumWithSubgroups { R"(HloModule iota ENTRY Iota { - ROOT iota = f32[100]{0} iota() + ROOT iota = f32[100]{0} iota(), iota_dimension=0 } )" }, -// custom-call with window and dim_labels +// custom-call with window, dim_labels and feature_group_count { -"CustomCallWithWindowAndDimLabels", -R"(HloModule CustomCallWithWindowAndDimLabels +"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount", +R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount ENTRY Computation { - ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target" +} + +)" + }, +// is_scheduled=true attribute +{ +"ScheduledModule", +R"(HloModule scheduled_module, is_scheduled=true + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} } )" } - }); +}); // clang-format on } class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface { protected: - static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected)) + static void ExpectHasSubstr(string_view s, string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } @@ -1346,7 +1399,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; @@ -1366,15 +1419,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; - ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=00_01_10", suffix)) - .status() - .error_message(), - "expects dim labels pattern"); + ExpectHasSubstr( + ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); ExpectHasSubstr( - ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) + ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), "must have the same rank"); @@ -1536,6 +1588,81 @@ ENTRY consts { "last"); } +TEST_F(HloParserTest, Comments) { + const string original = R"(/* module description. */ +HloModule comments: + +ENTRY /*comment*/ c1 { + /* blah */ + ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/}) + /* comment */ +} + +/* something else */ + +)"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, MultilineComments) { + const string original = R"(HloModule multiline_comment: +ENTRY c1 { + /* + ROOT foo = f32[1]{0} constant({12345}) + */ + ROOT const1 = f32[1]{0} constant({12345}) +/* +a +b +c +d + +*/ +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, UnterminatedComment) { + const string original = R"(HloModule unterminated_comment: +ENTRY c1 { +/* unterminated + ROOT const1 = f32[1]{0} constant({12345}) +})"; + // Verify that the error message points to the beginning of the unterminated + // comment. + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "/* unterminated\n^"); +} + +TEST_F(HloParserTest, SlashSlashComments) { + const string original = R"(HloModule slash_slash_comment: +// Garbage +ENTRY c1 { + // Foo bar + ROOT const1 = f32[1]{0} constant({12345}) // Something else +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) { + const string original = + "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo " + "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) { + const string original = + "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo " + "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + TEST_F(HloParserTest, MultipleEntries) { const string original = R"(HloModule multiple_entries: ENTRY c1 { @@ -1613,6 +1740,25 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); } +TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) { + const string original = "0_1x2_3"; + TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original)); + EXPECT_EQ(original, PaddingConfigToString(dnums)); +} + +TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) { + const string original = "0_1_0x2_3_4"; + TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original)); + EXPECT_EQ(original, PaddingConfigToString(dnums)); +} + +TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) { + TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4")); + // The extra "_0" gets added to the canonical string because the other dim has + // interior padding. + EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums)); +} + TEST_F(HloParserTest, NontupleInfeed) { const string original = R"(HloModule nontuple_infeed: ENTRY nontuple_infeed { @@ -1623,5 +1769,128 @@ ENTRY nontuple_infeed { "infeed must have a non-empty tuple shape"); } +TEST(HloParserSingleOpTest, SingleOp) { + const string text = + "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, " + "f32[2,4]{1,0} %x)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { + const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; + StatusOr> module = ParseHloOpToModule(text); + ASSERT_TRUE(!module.status().ok()); + LOG(INFO) << "Status: " << module.status(); + EXPECT_THAT( + module.status().ToString(), + ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); +} + +TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { + const string text = + R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Convolution(op::Parameter(0), op::Parameter(1))); + auto* convolution = + Cast(computation->root_instruction()); + EXPECT_EQ(convolution->feature_group_count(), 1); +} + +TEST_F(HloParserTest, IsScheduledIsFalse) { + const string text = R"( +HloModule axpy_module, is_scheduled=false + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledNotPresent) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledIsTrue) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), + op::Multiply(), op::Parameter(), op::Add())); +} + +TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { + // As above but in with a different schedule order. + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index 28194deb0e32252b372a328b006dabaf250fa2c7..791b1a97b0b82edf19ff1588fd8d5d996ac0fef4 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -45,7 +45,7 @@ class HloPassFix : public Pass { ++iteration_count; if (iteration_count == limit) { LOG(ERROR) - << "Unexpectedly number of iterations in HLO passes (" + << "Unexpectedly high number of iterations in HLO passes (" << iteration_count << ")\nIf compilation hangs here, please file a bug with XLA."; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index 0cddf8fb8f7589739d1233fa4974ff703211a137..fdaac34386c5135d6bbeb372d7a9199344836c8d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -25,15 +26,45 @@ limitations under the License. namespace xla { // Base class for HLO passes. These are used with the HloPassPipeline to -// organize a sequence of passes. +// organize a sequence of passes. An HLO pass should not extend this class +// directly; it should extend HloModulePass or HloModuleGroupPass. class HloPassInterface { public: virtual ~HloPassInterface() = default; - virtual tensorflow::StringPiece name() const = 0; + virtual absl::string_view name() const = 0; - // Run the pass on the given HLO module. Return whether it modified the + // Run the pass on the given HLO module. Returns whether it modified the // module. virtual StatusOr Run(HloModule* module) = 0; + + // Run the pass on the given HLO module group. Returns whether it modified the + // module group. Ideally, the module group variant would be named "Run" as + // well, but C++ does not handle overloaded virtual methods well. + virtual StatusOr RunOnModuleGroup(HloModuleGroup* module_group) = 0; +}; + +// Base class for passes which are module-scoped. +class HloModulePass : public HloPassInterface { + public: + // Runs the pass on a module group by iterating through each module in the + // group. + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + for (HloModule* module : module_group->modules()) { + TF_ASSIGN_OR_RETURN(bool module_changed, Run(module)); + changed |= module_changed; + } + return changed; + }; +}; + +// Base class for passes which are module-group scoped. These passes cannot run +// on an HLO module. +class HloModuleGroupPass : public HloPassInterface { + public: + StatusOr Run(HloModule* module) override { + return InternalError("Module group pass cannot be run on a module"); + } }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index d8f1ab916b5c5c500c2d8dcd8605be083f95862a..8c2f928ca101fae8e63663705554ae626c863bf6 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,119 +17,139 @@ limitations under the License. #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { -namespace { -void DumpModuleGraph(const HloModule& module, const string& message) { - hlo_graph_dumper::MaybeDumpHloModule(module, message); - VLOG(3) << "HLO " << message << ":"; - XLA_VLOG_LINES(3, module.ToString()); +template +Status HloPassPipeline::RunInvariantCheckers( + HloT* hlo, absl::string_view after_pass_name) { + for (auto& invariant_checker : invariant_checkers_) { + VLOG(1) << " Invariant checker " << invariant_checker->name(); + StatusOr changed_status = RunHelper(invariant_checker.get(), hlo); + VLOG(1) << " Invariant checker done " << invariant_checker->name(); + if (!changed_status.ok()) { + VLOG(2) << "Failed invariant check:"; + XLA_VLOG_LINES(2, hlo->ToString()); + return Status(changed_status.status().code(), + absl::StrCat(changed_status.status().error_message(), + "\n\nFailed after ", after_pass_name)); + } + TF_RET_CHECK(!changed_status.ValueOrDie()) + << "invariant checkers must not change the graph"; + } + return Status::OK(); } -void DumpModuleProto(const HloModule& module, const string& dump_to, - const string& pipeline_name, const string& pass_name) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static auto* const module_id_to_pass_number = - new tensorflow::gtl::FlatMap(); - - tensorflow::mutex_lock lock(mu); - const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; +template +StatusOr HloPassPipeline::RunPassesInternal( + HloT* hlo, absl::Span passes) { + string last_pass_name = "pipeline-start"; + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name)); + bool changed = false; + for (HloPassInterface* pass : passes) { + VLOG(1) << " HLO pass " << pass->name(); + MaybeDumpHlo(*hlo, + /*after_pass_name=*/last_pass_name, + /*before_pass_name=*/pass->name()); + TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo)); + changed |= pass_changed; + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name())); + last_pass_name = string(pass->name()); + } + MaybeDumpHlo(*hlo, + /*after_pass_name=*/last_pass_name, + /*before_pass_name=*/"pipeline-end"); + return changed; +} - const string mod_name = SanitizeFileName(tensorflow::strings::Printf( - "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number, - pipeline_name.c_str(), pass_name.c_str())); +std::vector HloPassPipeline::GetEnabledPasses( + const DebugOptions& debug_options) { + auto repeated_field = debug_options.xla_disable_hlo_passes(); + tensorflow::gtl::FlatSet disabled_pass_names(repeated_field.begin(), + repeated_field.end()); + if (!disabled_pass_names.empty()) { + VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " + << absl::StrJoin(disabled_pass_names, ", "); + } - TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), - dump_to, mod_name)); + std::vector enabled_passes; + for (auto& pass : passes_) { + if (disabled_pass_names.count(string(pass->name())) == 0) { + enabled_passes.push_back(pass.get()); + } + } + return enabled_passes; } -} // namespace -StatusOr HloPassPipeline::Run(HloModule* module) { - run_called_ = true; +void HloPassPipeline::MaybeDumpHlo(const HloModule& module, + absl::string_view after_pass_name, + absl::string_view before_pass_name) { + const string& proto_dump_path = + module.config().debug_options().xla_dump_per_pass_hlo_proto_to(); + if (!proto_dump_path.empty()) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static auto* const module_id_to_pass_number = + new tensorflow::gtl::FlatMap(); + + tensorflow::mutex_lock lock(mu); + const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; + + const string filename = SanitizeFileName( + absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), + pass_number, name(), after_pass_name)); + + TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory( + MakeHloProto(module), proto_dump_path, filename)); + } - VLOG(1) << "Running HLO pass pipeline " << name(); + const string message = + StrCat("after ", after_pass_name, ", before ", before_pass_name); + hlo_graph_dumper::MaybeDumpHloModule(module, message); + VLOG(3) << "HLO " << message << ":"; + XLA_VLOG_LINES(3, module.ToString()); +} - auto repeated_field = - module->config().debug_options().xla_disable_hlo_passes(); - tensorflow::gtl::FlatSet disabled_passes(repeated_field.begin(), - repeated_field.end()); - if (!disabled_passes.empty()) { - VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " - << tensorflow::str_util::Join(disabled_passes, ", "); +void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name) { + for (const HloModule* module : module_group.modules()) { + MaybeDumpHlo(*module, after_pass_name, before_pass_name); } +} - auto run_invariant_checkers = [this, - module](const string& message) -> Status { - for (auto& invariant_checker : invariant_checkers_) { - VLOG(1) << " Invariant checker " << invariant_checker->name(); - StatusOr changed_status = invariant_checker->Run(module); - VLOG(1) << " Invariant checker done " << invariant_checker->name(); - if (!changed_status.ok()) { - VLOG(2) << "Module failed invariant check:"; - XLA_VLOG_LINES(2, module->ToString()); - return Status(changed_status.status().code(), - StrCat(changed_status.status().error_message(), - "\n\nFailed ", message)); - } - TF_RET_CHECK(!changed_status.ValueOrDie()) - << "invariant checkers must not change the graph"; - } - return Status::OK(); - }; +StatusOr HloPassPipeline::Run(HloModule* module) { + run_called_ = true; - string prefix = std::string(name()) + ": pipeline start"; - bool changed = false; - string message; - TF_RETURN_IF_ERROR( - run_invariant_checkers(StrCat("before running pipeline: ", name()))); - const string xla_dump_per_pass_hlo_proto_to = - module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); - if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), "pipeline_start"); - } + VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": " + << name(); - for (auto& pass : passes_) { - if (disabled_passes.count(std::string(pass->name())) > 0) { - VLOG(1) << " Skipping HLO pass " << pass->name() - << ", disabled by --xla_disable_hlo_passes"; - continue; - } + return RunPassesInternal(module, + GetEnabledPasses(module->config().debug_options())); +} - VLOG(1) << " HLO pass " << pass->name(); +StatusOr HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) { + run_called_ = true; - // Emit label containing: "after foo-pass, before bar-pass". - message.clear(); - StrAppend(&message, prefix, ", before ", pass->name()); - DumpModuleGraph(*module, message); - - TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); - TF_RETURN_IF_ERROR( - run_invariant_checkers(StrCat("after running pass: ", pass->name()))); - if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), std::string(pass->name())); - } + VLOG(1) << "Running HLO pass pipeline on module group " + << module_group->name() << ": " << name(); - changed |= changed_this_pass; - prefix.clear(); - StrAppend(&prefix, name(), ": after ", pass->name()); + if (module_group->modules().empty()) { + VLOG(1) << "Module group is empty. Nothing to do."; + return false; } - DumpModuleGraph(*module, prefix + ", pipeline end"); - return changed; + + return RunPassesInternal( + module_group, + GetEnabledPasses(module_group->module(0).config().debug_options())); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index a42d7e59fed2d838dfe3cb7f99e6b946edfdb0b4..09e7033ea4ed88849d2f3665d04f74f3f388b3f5 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -21,7 +21,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,7 +35,7 @@ namespace xla { class HloPassPipeline : public HloPassInterface { public: explicit HloPassPipeline(const string& name) : name_(name) {} - tensorflow::StringPiece name() const override { return name_; } + absl::string_view name() const override { return name_; } // Add a pass to the pipeline. It should be called with the arguments for the // pass constructor: @@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface { return *pass; } - // Run all passes on the given HLO module. StatusOr Run(HloModule* module) override; + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override; private: + // Returns the set of passes which are enabled. DebugOptions can selectively + // disable passes via --xla_disable_hlo_passes flag. + std::vector GetEnabledPasses( + const DebugOptions& debug_options); + + // Maybe dumps the given module or module group depending on flag values + // contained in DebugOptions of module config. + void MaybeDumpHlo(const HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name); + void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name, + absl::string_view before_pass_name); + + // Runs the invariant checker on the given HLO. HloT can be either HloModule + // or HloModuleGroup. + template + Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name); + + // Helper which runs the given pass on the given HLO. HloT can be either + // HloModule or HloModuleGroup. + template + StatusOr RunPassesInternal(HloT* hlo, + absl::Span passes); + + // Helpers which run the given passes on the given HLO construct. These + // helpers enable templating of the core of the pipeline logic by providing + // HloModule and HloModuleGroup specific methods with the same name. + static StatusOr RunHelper(HloPassInterface* pass, HloModule* module) { + return pass->Run(module); + } + static StatusOr RunHelper(HloPassInterface* pass, + HloModuleGroup* module_group) { + return pass->RunOnModuleGroup(module_group); + } + const string name_; std::vector> passes_; std::vector> invariant_checkers_; diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee8cb12b231718e09f6ac0d05d7a6887f4c4d746 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -0,0 +1,259 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloPassPipelineTest : public HloVerifiedTestBase { + protected: + StatusOr ParseModuleGroup( + absl::Span hlo_strings) { + HloModuleGroup group(TestName()); + for (const string& hlo_string : hlo_strings) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + group.push_back(std::move(module)); + } + return std::move(group); + } +}; + +// A module pass which renames instructions named 'foo' to 'bar'. +class FooToBarModulePass : public HloModulePass { + absl::string_view name() const override { return "foo2bar"; } + + StatusOr Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "foo") { + instruction->SetAndSanitizeName("bar"); + changed = true; + } + } + } + return changed; + } +}; + +// A module group pass which renames instructions named 'baz' to 'qux'. +class BazToQuxModuleGroupPass : public HloModuleGroupPass { + absl::string_view name() const override { return "baz2qux"; } + + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + for (HloModule* module : module_group->modules()) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "baz") { + instruction->SetAndSanitizeName("qux"); + changed = true; + } + } + } + } + return changed; + } +}; + +// An invariant checker pass which returns an error if there exists an +// instruction named 'bar'. +class BarBlowerUpper : public HloModulePass { + absl::string_view name() const override { return "bar-blower-upper"; } + + StatusOr Run(HloModule* module) override { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "bar") { + return InternalError("Module has instruction named bar"); + } + } + } + return false; + } +}; + +TEST_F(HloPassPipelineTest, ModulePassChanged) { + // Test an HLO module pass which changes a module. + const string module_str = R"( +HloModule ModulePassChanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "foo"); + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_EQ(root->name(), "bar"); +} + +TEST_F(HloPassPipelineTest, ModulePassUnchanged) { + // Test an HLO module pass which does not change a module. + const string module_str = R"( +HloModule ModulePassUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT blahblah = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(HloPassPipelineTest, MixedPipeline) { + // Test a pipeline with both a module pass and a module group pass. + const string module_0_str = R"( +HloModule MixedPipeline.1 + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT baz = f32[] multiply(a, b) +} +)"; + const string module_1_str = R"( +HloModule MixedPipeline.0 + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group, + ParseModuleGroup({module_0_str, module_1_str})); + + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + pipeline.AddPass(); + + HloInstruction* root0 = + module_group.module(0).entry_computation()->root_instruction(); + HloInstruction* root1 = + module_group.module(1).entry_computation()->root_instruction(); + EXPECT_EQ(root0->name(), "baz"); + EXPECT_EQ(root1->name(), "foo"); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + pipeline.RunOnModuleGroup(&module_group)); + EXPECT_TRUE(changed); + + EXPECT_EQ(root0->name(), "qux"); + EXPECT_EQ(root1->name(), "bar"); +} + +TEST_F(HloPassPipelineTest, InvariantChecker) { + const string module_str = R"( +HloModule InvariantChecker + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + { + // Run a pipeline with just the invariant checker. It should not fail + // because there is no 'bar' instruction in the module. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker(); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_FALSE(changed); + } + + { + // Run a pipeline which renames 'foo' to 'bar' then an invariant checker + // which fails if there is an instruction named 'bar'. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Module has instruction named bar")); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Failed after foo2bar")); + } + + { + // Run the invariant-checker only pipeline again. It should fail this time. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Module has instruction named bar")); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Failed after pipeline-start")); + } +} + +TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) { + // Running a module group pass on a module should produce an error. + const string module_str = R"( +HloModule ModuleGroupPassOnModule + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Module group pass cannot be run on a module")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 3460679558d185d1e022660d9a1d23176d0d96bf..b9c0b0c4ee1957fce48641230cef6391bcc9180e 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -23,11 +23,8 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { - HloOrderingProto proto_ordering = - assignment.liveness().hlo_ordering().ToProto(); BufferAssignmentProto proto_assignment = assignment.ToProto(); HloProto proto = MakeHloProto(module); - proto.mutable_hlo_ordering()->Swap(&proto_ordering); proto.mutable_buffer_assignment()->Swap(&proto_assignment); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc index b9cca138703c8fa61aadf69dd7304a215a9f4be2..c3cacd7ce6b1ea3ad7cf84e898f274ae12622ac5 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 01b088a957554821e65db7bf9cedf334db49728f..961930f0a888e90f86e4354fa1373a303af8ec2f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { HloReachabilityMap::HloReachabilityMap( - tensorflow::gtl::ArraySlice instructions) + absl::Span instructions) : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { @@ -29,7 +29,7 @@ HloReachabilityMap::HloReachabilityMap( } bool HloReachabilityMap::SetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; @@ -38,13 +38,13 @@ bool HloReachabilityMap::SetReachabilityToUnion( } void HloReachabilityMap::FastSetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction) { SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); } void HloReachabilityMap::SetReachabilityToUnionHelper( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 48215d32a8284919cce6beb1663e6a723eefc1c4..b66a2aa4bd2b00a88cdbfa6b41c9123bb370aa87 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +42,7 @@ class HloReachabilityMap { // Sets up a graph with no edges and where the nodes correspond to the given // instructions. explicit HloReachabilityMap( - tensorflow::gtl::ArraySlice instructions); + absl::Span instructions); // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where @@ -54,13 +54,12 @@ class HloReachabilityMap { // vector in the internal graph of this HloReachabilityMap for the given // instruction and does not transitively update any other part of the // adjacency matrix. - bool SetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, - const HloInstruction* instruction); + bool SetReachabilityToUnion(absl::Span inputs, + const HloInstruction* instruction); // As above, but faster because it does not check if the reachability changed. void FastSetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction); // Sets entry so that IsReachable(a, b) will return true @@ -141,7 +140,7 @@ class HloReachabilityMap { // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. void SetReachabilityToUnionHelper( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector); // Return the index of the given instruction. The value is used to index into diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index 585c95972b0e01abc14543205af71b4b0c0bdf3c..d9848cee0bfa904a90aea4626c3ee62c2cbb45b6 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloTestBase {}; +class HloReachabilityTest : public HloVerifiedTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index cf0be30c7ad5cbeb7fd3d71c7c649b6b448360b8..a43867193628d05ad7703a5d5ed8bdc9c72de581 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -20,34 +20,33 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Potential optimizations: // . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue // of candidates. @@ -88,7 +87,7 @@ bool CanBeRematerialized( // Type holding a unique identifier for each Buffer object. using BufferId = int64; -using BufferIdList = tensorflow::gtl::InlinedVector; +using BufferIdList = absl::InlinedVector; // We wrap HloInstruction* with an Item that holds auxiliary // per-instruction state. @@ -123,7 +122,7 @@ struct Item { int64 position; }; -using ItemList = tensorflow::gtl::InlinedVector; +using ItemList = absl::InlinedVector; // Class which maintains an ordered list of instructions with fast insertion // before arbitrary elements. @@ -202,15 +201,14 @@ class InstructionList { // On object construction this ordinal is precisely the instruction's index // in the list. Later, instructions inserted via InsertBefore receive // duplicate values. However, monotonicity is preserved. - void InsertBeforeInstructions( - Item* to_insert, tensorflow::gtl::ArraySlice before_instructions) { + void InsertBeforeInstructions(Item* to_insert, + absl::Span before_instructions) { VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name() << " before {" - << tensorflow::str_util::Join(before_instructions, ", ", - [](string* out, Item* item) { - tensorflow::strings::StrAppend( - out, item->instruction->name()); - }) + << absl::StrJoin(before_instructions, ", ", + [](string* out, Item* item) { + absl::StrAppend(out, item->instruction->name()); + }) << "}"; // Find the minimal position number of any instruction in @@ -393,10 +391,9 @@ class MemoryUsageTracker { int64 unfinished_user_count; string ToString() const { - return tensorflow::strings::StrCat( - "Buffer ", id, " (defined by ", - defining_instruction->instruction->name(), ", size ", size, - " bytes)"); + return absl::StrCat("Buffer ", id, " (defined by ", + defining_instruction->instruction->name(), ", size ", + size, " bytes)"); } }; @@ -740,29 +737,27 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, } string MemoryUsageTracker::ToString() const { - string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", - computation_->name(), "\n"); - tensorflow::strings::StrAppend( - &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", - memory_usage(), " bytes)"); + string output = + absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n"); + absl::StrAppend(&output, + "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", + memory_usage(), " bytes)"); for (auto* item = instruction_list_.first(); item != nullptr; item = instruction_list_.next(item)) { const HloInstruction* instruction = item->instruction; string inprogress = item == in_progress_item_ ? " in-progress" : ""; string placed = item->placed ? " placed" : ""; - tensorflow::strings::StrAppend(&output, " ", instruction->name(), - inprogress, placed, "\n Defines:\n"); + absl::StrAppend(&output, " ", instruction->name(), inprogress, placed, + "\n Defines:\n"); for (BufferId buffer_id : item->buffers_defined) { const Buffer& buffer = buffers_[buffer_id]; string live = IsCurrentlyLive(buffer_id) ? " live" : ""; - tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, - ", ", buffer.unfinished_user_count, - " unfinished uses\n"); + absl::StrAppend(&output, " ", buffer.ToString(), live, ", ", + buffer.unfinished_user_count, " unfinished uses\n"); } - tensorflow::strings::StrAppend(&output, " Uses:\n"); + absl::StrAppend(&output, " Uses:\n"); for (BufferId buffer_id : item->buffers_used) { - tensorflow::strings::StrAppend(&output, " ", - buffers_[buffer_id].ToString(), "\n"); + absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n"); } } return output; @@ -780,10 +775,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(defined_buffers)) << "Instruction " << instruction->name() << " does not have unique defined buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( defined_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); for (const Buffer& buffer : buffers_) { @@ -803,10 +797,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(used_buffers)) << "Instruction " << instruction->name() << " does not have unique used buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( used_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); } for (const Buffer& buffer : buffers_) { @@ -968,8 +961,7 @@ StatusOr HloRematerialization::CalledComputationsMemoryUsage( } StatusOr HloRematerialization::RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, + HloComputation* computation, HloSchedule* schedule, int64 memory_limit_bytes) { VLOG(1) << "Rematerializing computation " << computation->name() << " with limit " << HumanReadableNumBytes(memory_limit_bytes); @@ -977,7 +969,8 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list(sequence->at(computation)); + InstructionList instruction_list( + schedule->sequence(computation).instructions()); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1151,7 +1144,7 @@ StatusOr HloRematerialization::RematerializeComputation( 0, memory_limit_bytes - memory_tracker.memory_usage()); TF_ASSIGN_OR_RETURN( bool subcomputation_changed, - RematerializeComputation(called_computation, sequence, + RematerializeComputation(called_computation, schedule, subcomputation_memory_limit_bytes)); changed |= subcomputation_changed; } @@ -1185,12 +1178,12 @@ StatusOr HloRematerialization::RematerializeComputation( computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. - auto& dst = sequence->at(computation); - dst.clear(); + HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation); + sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { const HloInstruction* instruction = item->instruction; - dst.push_back(instruction); + sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1200,16 +1193,18 @@ StatusOr HloRematerialization::RematerializeComputation( return changed; } -StatusOr HloRematerialization::Run( - HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes, RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The sequence is constructed entirely by this method. - TF_RET_CHECK(sequence->empty()); - +StatusOr HloRematerialization::Run(HloModule* module) { VLOG(1) << "HloRematerialization() with memory limit of " - << HumanReadableNumBytes(memory_limit_bytes); + << HumanReadableNumBytes(memory_limit_bytes_); + XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); + // Initialize pass object state. + computation_peak_memory_.clear(); + rematerialized_computations_.clear(); + instructions_rematerialized_ = 0; + net_instructions_added_ = 0; + + TF_RET_CHECK(module->has_schedule()); TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); // Adjust memory limit to account for the output of the entry @@ -1225,39 +1220,23 @@ StatusOr HloRematerialization::Run( }); const int64 adjusted_memory_limit_bytes = - memory_limit_bytes - module_output_size; + memory_limit_bytes_ - module_output_size; VLOG(1) << "Adjusted memory limit accounting for output (" << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); - XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( - *module, - [this](const BufferValue& buffer) { - return size_function_(buffer.shape()); - }, - scheduler_algorithm_)); - if (copy_insertion) { - // We run a separate pass of copy elision here because the sequential - // ordering from the HLO schedule allows for more copies to be eliminated. - // TODO(b/80249101): Instead of a separate copy elision pass, use the - // ordering from the HLO schedule directly for copy insertion. - SequentialHloOrdering ordering(module, *sequence); - TF_RETURN_IF_ERROR( - copy_insertion->RemoveUnnecessaryCopies(ordering, module)); - } - // Compute peak memory usage of all computations in the module called in a // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, sequence](const CallGraphNode& node) -> Status { + [this, module](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], ComputePeakMemory(node.computation(), - sequence->at(node.computation()))); + module->schedule() + .sequence(node.computation()) + .instructions())); } return Status::OK(); }, @@ -1275,9 +1254,10 @@ StatusOr HloRematerialization::Run( // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), sequence, - adjusted_memory_limit_bytes)); + TF_ASSIGN_OR_RETURN( + bool changed, + RematerializeComputation(module->entry_computation(), &module->schedule(), + adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. @@ -1286,30 +1266,7 @@ StatusOr HloRematerialization::Run( // After DCE, the module sequence may include instructions which no longer // exist. - for (const auto* computation : module->MakeNonfusionComputations()) { - if (sequence->at(computation).size() != computation->instruction_count()) { - // A size mismatch between the computation instruction count and the size - // of the ordering of instructions can only be caused by DCE. Rebuild the - // order by removing the deleted instructions from the order. - tensorflow::gtl::FlatSet instruction_set; - for (const auto& instruction : computation->instructions()) { - instruction_set.insert(instruction); - } - // Move the old order into a temporary vector, then build new order - // inplace. - std::vector& order = sequence->at(computation); - std::vector old_order; - using std::swap; - swap(order, old_order); - std::copy_if(old_order.begin(), old_order.end(), - std::back_inserter(order), - [&instruction_set](const HloInstruction* instruction) { - return ContainsKey(instruction_set, instruction); - }); - TF_RET_CHECK(sequence->at(computation).size() == - computation->instruction_count()); - } - } + TF_RETURN_IF_ERROR(module->schedule().Update()); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1326,34 +1283,22 @@ StatusOr HloRematerialization::Run( << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; - if (sizes != nullptr) { - sizes->before_bytes = before_peak_memory; - sizes->after_bytes = current_peak_memory; + if (sizes_ != nullptr) { + sizes_->before_bytes = before_peak_memory; + sizes_->after_bytes = current_peak_memory; } XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); - if (current_peak_memory > memory_limit_bytes) { - LOG(WARNING) << tensorflow::strings::Printf( - "Can't reduce memory use below %s (%lld bytes) by rematerialization; " - "only reduced to %s (%lld bytes)", - HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes, - HumanReadableNumBytes(current_peak_memory).c_str(), - current_peak_memory); + if (current_peak_memory > memory_limit_bytes_) { + LOG(WARNING) << absl::StrFormat( + "Can't reduce memory use below %s (%d bytes) by rematerialization; " + "only reduced to %s (%d bytes)", + HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, + HumanReadableNumBytes(current_peak_memory), current_peak_memory); } return changed; } -/* static */ StatusOr HloRematerialization::RematerializeAndSchedule( - const HloRematerialization::ShapeSizeFunction& size_function, - int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, CopyInsertion* copy_insertion) { - HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, - copy_insertion); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ec004350ad88ff31ece90ec419d90a55b965166..7330d73c09eb5aa8265fa5753a2de5885f51bf15 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -17,16 +17,23 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { -class HloRematerialization { +// HLO pass which rematerializes instructions to reduce peak memory use, where +// memory use is defined as the total size of all live HLO instruction +// values. Parameters and constants are included in memory use estimates. +// +// CSE will undo the effects of this optimization and should not be run after +// this pass. In general, this pass should be run very late, immediately before +// code generation. +class HloRematerialization : public HloModulePass { public: using ShapeSizeFunction = std::function; @@ -37,10 +44,7 @@ class HloRematerialization { int64 after_bytes; }; - // Rematerialize HLO instructions in the given module to reduce peak memory - // use below memory_limit_bytes where memory use is defined as the total size - // of all live HLO instruction values. Parameters and constants are included - // in memory use estimates. Method parameters: + // Constructor parameters: // // size_function: Function which returns the size in bytes of the top-level // buffer of the given shape. @@ -48,60 +52,34 @@ class HloRematerialization { // memory_limit_bytes: The threshold number of bytes to reduce memory use to // via rematerialization. // - // hlo_module: HLO module to rematerialize instructions in. - // - // sequence: Should point to an empty HloModuleSequence. Upon return - // contains the HLO instruction order which was used for - // rematerialization. This is the order in which HLO instructions should - // be emitted to minimize memory use. - // - // sizes: Optional outparam that indicates the peak memory usage of the HLO - // module before/after rematerialization. - // - // copy_insertion: If non-null, run copy elision after scheduling. This - // pass is used to eliminate copies that were inserted by copy insertion - // before HLO scheduling. - // - // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy - // insertion is integrated with HLO scheduling. - // - // Returns whether any instructions were rematerialized. If memory use is - // already below the given limit then no instructions are rematerialized and - // false is returned. - // - // CSE will undo the effects of this optimization and should not be run after - // this pass. In general, this pass should be run very late immediately before - // code generation. - static StatusOr RematerializeAndSchedule( - const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); - - protected: - HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, - const ShapeSizeFunction& size_function) - : scheduler_algorithm_(scheduler_algorithm), - size_function_(size_function) {} + // sizes: Pointer to data structure which records the peak memory usage of + // the HLO module before/after rematerialization. Value are set during + // Run(). Can be nullptr. + HloRematerialization(const ShapeSizeFunction& size_function, + int64 memory_limit_bytes, RematerializationSizes* sizes) + : size_function_(size_function), + memory_limit_bytes_(memory_limit_bytes), + sizes_(sizes) {} ~HloRematerialization() {} + absl::string_view name() const override { return "rematerialization"; } + // Runs rematerialization on the given module. Returns whether the module was - // changed. memory_limit is the target maximum peak memory usage by the - // module. sequence should be an empty HloModuleSequence. Upon return sequence - // contains the memory-minimizing order in which to emit the HLO instructions. - StatusOr Run(HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit, RematerializationSizes* sizes, - CopyInsertion* copy_insertion); + // changed. Requires that the module has a schedule set + // (HloModule::has_schedule() is true) before running. Returns whether any + // instructions were rematerialized. If memory use is already below the limit + // specified in the constructor then no instructions are rematerialized and + // false is returned. + StatusOr Run(HloModule* module) override; + protected: // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation // and inserted into 'order'. - StatusOr RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, - int64 computation_memory_limit); + StatusOr RematerializeComputation(HloComputation* computation, + HloSchedule* schedule, + int64 memory_limit_bytes); // Computes and returns the peak memory used by the given computation. The // peak memory is the maximum total size of all live HLO instruction values at @@ -122,6 +100,14 @@ class HloRematerialization { // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; + // The threshold number of bytes to reduce memory use to via + // rematerialization. + const int64 memory_limit_bytes_; + + // Pointer to data structure which records the peak memory usage of the HLO + // module before/after rematerialization + RematerializationSizes* sizes_; + // Call graph of the hlo_module. std::unique_ptr call_graph_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index ac8c97d380953764b66135ad1c5fcee0d481c004..f7e82fb1f88e856305f6f481a451d4cd64ba4acf 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloTestBase { +class HloRematerializationTest : public HloVerifiedTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -141,13 +141,16 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } - StatusOr RunHloRematerialization( - int64 memory_limit_bytes, HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence) { + StatusOr RunHloRematerialization(int64 memory_limit_bytes, + HloModule* module) { TF_EXPECT_OK(verifier().Run(module).status()); - return HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - sequence, /*sizes=*/nullptr); + HloMemoryScheduler scheduler( + [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); }, + DefaultMemoryScheduler); + TF_EXPECT_OK(scheduler.Run(module).status()); + HloRematerialization remat(ByteSizeOf, memory_limit_bytes, + /*sizes=*/nullptr); + return remat.Run(module); } // Various shapes used in the canned computations. @@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, module)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,9 +189,13 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2], + EXPECT_EQ(module->schedule() + .sequence(computation) + .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3], + EXPECT_EQ(module->schedule() + .sequence(computation) + .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -203,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, module)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -242,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, module)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -276,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, module)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -316,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, module)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -382,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &sequence)); + bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -571,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index b2725e2918ce76248d9f2cdbb2a6e5a63226bf9a..fa7f216321988137dcf9104a324f5f7789869aa5 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,7 +32,7 @@ limitations under the License. namespace xla { /*static*/ StatusOr> -HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, +HloRunner::CreateModuleFromString(const absl::string_view hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); @@ -106,7 +106,7 @@ StatusOr HloRunner::TransferLiteralToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice literals) { + const absl::Span literals) { std::vector buffers; for (const Literal* literal : literals) { CHECK(literal != nullptr); @@ -118,16 +118,16 @@ StatusOr> HloRunner::TransferLiteralsToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice> literals) { + const absl::Span literals) { std::vector literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { - literal_pointers.push_back(literal.get()); + literal_pointers.push_back(&literal); } return TransferLiteralsToDevice(literal_pointers); } -StatusOr> HloRunner::TransferLiteralFromDevice( +StatusOr HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { TF_ASSIGN_OR_RETURN( auto stream, backend().BorrowStream(backend().default_stream_executor())); @@ -135,10 +135,10 @@ StatusOr> HloRunner::TransferLiteralFromDevice( buffer); } -StatusOr> HloRunner::Execute( +StatusOr HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { TF_ASSIGN_OR_RETURN(std::vector argument_buffers, TransferLiteralsToDevice(arguments)); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, @@ -150,15 +150,15 @@ StatusOr> HloRunner::Execute( return TransferLiteralFromDevice(result); } -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice> arguments, - bool run_hlo_passes, ExecutionProfile* profile) { +StatusOr HloRunner::Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes, + ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. std::vector argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { - argument_pointers.push_back(argument.get()); + argument_pointers.push_back(&argument); } return Execute( /*module=*/std::move(module), @@ -169,8 +169,8 @@ StatusOr> HloRunner::Execute( StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { // Get service run options. se::Stream stream(backend().default_stream_executor()); stream.Init(); @@ -190,8 +190,8 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { std::vector argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { @@ -204,7 +204,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( /*profile=*/profile); } -StatusOr>> HloRunner::ExecuteReplicated( +StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { TF_ASSIGN_OR_RETURN( @@ -226,14 +226,13 @@ StatusOr>> HloRunner::ExecuteReplicated( // no arguments. std::vector argument_buffer_ptrs( options.num_replicas * options.arguments.size() + 1); - std::vector> - argument_buffer_slices; + std::vector> argument_buffer_slices; int64 index = 0; for (int64 i = 0; i < options.num_replicas; ++i) { int64 device = device_assignment(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(MakeUnique(executor)); + streams.push_back(absl::make_unique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), &device_assignment)); @@ -260,7 +259,7 @@ StatusOr>> HloRunner::ExecuteReplicated( num_threads += options.num_replicas; } if (num_threads > 0) { - pool = MakeUnique( + pool = absl::make_unique( tensorflow::Env::Default(), "infeed_outfeed", /*num_threads=*/num_threads); } @@ -291,9 +290,9 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = MakeUnique(); + Literal literal; TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, options.outfeed_shape, literal.get())); + executor, options.outfeed_shape, &literal)); if (options.outfeed_values != nullptr) { options.outfeed_values->push_back(std::move(literal)); } @@ -311,10 +310,10 @@ StatusOr>> HloRunner::ExecuteReplicated( argument_buffer_slices)); LOG(INFO) << "Replicated execution terminated"; - std::vector> exec_results; + std::vector exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, backend().transfer_manager()->TransferLiteralFromDevice( streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 65537f07f56e74b7fe2c2f9792af21efc7229573..2e934bf66ae43ea412f242030b874dddb6d3722d 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -72,7 +72,7 @@ class HloRunner { // A pointer to a vector where the outfeed values will be stored. If // nullptr, the values will be read and discarded. - std::vector>* outfeed_values = nullptr; + std::vector* outfeed_values = nullptr; // Whether the HLO passes should be run on the input module. Usually // saved modules are coming from after the HLO pass pipeline, so triggering @@ -87,8 +87,7 @@ class HloRunner { // Converts an HloModule from the given hlo textual IR string (in // HloModule::ToString format). static StatusOr> CreateModuleFromString( - const tensorflow::StringPiece hlo_string, - const DebugOptions& debug_options); + const absl::string_view hlo_string, const DebugOptions& debug_options); // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. @@ -105,43 +104,42 @@ class HloRunner { // Transfers data between the host and device. StatusOr TransferLiteralToDevice(const Literal& literal); StatusOr> TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice literals); + const absl::Span literals); StatusOr> TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice> literals); - StatusOr> TransferLiteralFromDevice( - const ShapedBuffer& buffer); + const absl::Span literals); + StatusOr TransferLiteralFromDevice(const ShapedBuffer& buffer); // Executes the given module with given literals as input and returns the // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - StatusOr> Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); - StatusOr> Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice> arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); // As Execute(), but accepts and returns device buffers instead of host // buffers. StatusOr ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. - StatusOr>> ExecuteReplicated( + StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options); diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc new file mode 100644 index 0000000000000000000000000000000000000000..3fc5dbeb02a26134a7f255fa0b6ebda1dc41ce4d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -0,0 +1,343 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { + +/* static */ StatusOr HloSchedule::CreateFromProto( + const HloModule* module, const HloScheduleProto& proto) { + tensorflow::gtl::FlatMap id_to_computation; + for (const HloComputation* computation : module->computations()) { + id_to_computation[computation->unique_id()] = computation; + } + + HloSchedule schedule(module); + for (const auto& id_sequence : proto.sequences()) { + int64 computation_id = id_sequence.first; + + auto comp_it = id_to_computation.find(computation_id); + TF_RET_CHECK(comp_it != id_to_computation.end()) + << "No computation exists in HLO module with id " << computation_id; + const HloComputation* computation = comp_it->second; + + tensorflow::gtl::FlatMap id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + id_to_instruction[instruction->unique_id()] = instruction; + } + + HloInstructionSequence& sequence = + schedule.GetOrCreateSequence(computation); + for (const int64 instruction_id : id_sequence.second.instruction_ids()) { + auto instr_it = id_to_instruction.find(instruction_id); + TF_RET_CHECK(instr_it != id_to_instruction.end()) + << "No instruction exists in HLO computation " << computation->name() + << " with id " << instruction_id; + sequence.push_back(instr_it->second); + } + } + TF_RETURN_IF_ERROR(schedule.Verify()); + return std::move(schedule); +} + +StatusOr HloSchedule::ToProto() const { + TF_RETURN_IF_ERROR(Verify()); + HloScheduleProto proto; + for (const auto& id_sequence : sequences_) { + int64 computation_id = id_sequence.first; + const HloInstructionSequence& sequence = id_sequence.second; + HloScheduleProto::InstructionSequence& proto_sequence = + (*proto.mutable_sequences())[computation_id]; + proto_sequence.mutable_instruction_ids()->Reserve(sequence.size()); + for (const int64 id : sequence.ids()) { + proto_sequence.add_instruction_ids(id); + } + } + return std::move(proto); +} + +void HloSchedule::set_sequence( + const HloComputation* computation, + absl::Span sequence) { + set_sequence(computation, HloInstructionSequence(sequence)); +} + +void HloSchedule::set_sequence(const HloComputation* computation, + HloInstructionSequence sequence) { + CHECK(computation->parent() == module_); + sequences_[computation->unique_id()] = std::move(sequence); +} + +HloInstructionSequence& HloSchedule::GetOrCreateSequence( + const HloComputation* computation) { + auto it = sequences_.find(computation->unique_id()); + if (it == sequences_.end()) { + // No sequence found for computation. Create and return an empty one. + CHECK(computation->parent() == module_); + return sequences_[computation->unique_id()]; + } else { + return it->second; + } +} + +const HloInstructionSequence& HloSchedule::sequence( + const HloComputation* computation) const { + return sequences_.at(computation->unique_id()); +} + +Status HloSchedule::UpdateComputationSchedule( + const HloComputation* computation) { + // Map from unique ID to HloInstruction pointer for instructions in the + // computation. + tensorflow::gtl::FlatMap id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); + } + + // Set of all HloInstructions in the schedule. + tensorflow::gtl::FlatSet ids_in_schedule; + for (int id : sequences_.at(computation->unique_id()).ids()) { + InsertOrDie(&ids_in_schedule, id); + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // computation, but not in schedule) which use X. If an instruction is not in + // the map, then it has no users which are newly added instructions. + tensorflow::gtl::FlatMap> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. When this value reaches zero, + // then the instruction may be placed in the schedule. + tensorflow::gtl::FlatMap + unscheduled_operand_count; + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue worklist; + + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. + if (instruction->operands().empty()) { + worklist.push(instruction); + } else { + for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. + HloInstructionSequence new_sequence; + + // Lambda which schedules all instructions on the worklist. + auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + new_sequence.push_back(instruction); + std::vector* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. + for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : sequences_.at(computation->unique_id()).ids()) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. Do not add + // it to the new schedule. + continue; + } + worklist.push(it->second); + schedule_worklist(); + } + + set_sequence(computation, std::move(new_sequence)); + return Status::OK(); +} + +Status HloSchedule::Update() { + // The schedule must contain a sequence for every non-fusion computation in + // the module, but can have sequences for computations which no longer exist + // (these are removed). + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() << " not in HloSchedule."; + } + if (sequences_.size() > nonfusion_computations.size()) { + // Schedule contains some computations which have been removed from the + // HloModule. Remove them from the schedule as well. + tensorflow::gtl::FlatSet nonfusion_computations_ids; + for (const HloComputation* computation : nonfusion_computations) { + nonfusion_computations_ids.insert(computation->unique_id()); + } + for (auto it = sequences_.begin(); it != sequences_.end();) { + if (nonfusion_computations_ids.count(it->first) == 0) { + it = sequences_.erase(it); + } else { + it++; + } + } + } + CHECK_EQ(sequences_.size(), nonfusion_computations.size()); + + for (const HloComputation* computation : nonfusion_computations) { + TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation)); + } + + TF_RETURN_IF_ERROR(Verify()); + return Status::OK(); +} + +Status HloSchedule::Verify() const { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(3, module_->ToString()); + XLA_VLOG_LINES(2, ToString()); + + // Verify schedule contains exactly the same set of non-fusion computations as + // module currently does. + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequences_.size()) + << "Schedule has " << sequences_.size() << " sequences, but module has " + << nonfusion_computations.size() << " non-fusion computations"; + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() + << " missing from HLO schedule."; + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. + for (const HloComputation* computation : nonfusion_computations) { + tensorflow::gtl::FlatMap instruction_position; + int pos = 0; + for (const HloInstruction* instruction : + sequence(computation).instructions()) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + } + + return Status::OK(); +} + +namespace { + +// Returns the computation in the given module with the given unique ID. Returns +// nullptr if no such computation exists. +const HloComputation* IdToComputation(const HloModule* module, int64 id) { + for (const HloComputation* computation : module->computations()) { + if (computation->unique_id() == id) { + return computation; + } + } + return nullptr; +} + +} // namespace + +string HloSchedule::ToString() const { + std::vector pieces; + + pieces.push_back("HloSchedule"); + for (const auto& id_sequence : sequences_) { + const HloComputation* computation = + IdToComputation(module_, id_sequence.first); + if (computation == nullptr) { + // The computation is not in the module and may have been deleted so it is + // not safe to dereference any HLO pointers. Just use the HLO unique ids + // stored in this object. + pieces.push_back( + absl::StrFormat("computation with id %d (no longer in HLO module):", + id_sequence.first)); + for (int id : id_sequence.second.ids()) { + pieces.push_back(absl::StrCat(" ", id)); + } + } else { + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); + for (const HloInstruction* instruction : + id_sequence.second.instructions()) { + pieces.push_back(absl::StrCat(" ", instruction->name())); + } + } + } + return absl::StrJoin(pieces, "\n"); +} + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) { + out << schedule.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h new file mode 100644 index 0000000000000000000000000000000000000000..270fe6039f0afd119c76086de9a0596e0560e93e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { + +class HloModule; + +// Class representing a sequence of HLO instructions such as the sequential +// execution order of an HLO computation. +class HloInstructionSequence { + public: + HloInstructionSequence() = default; + explicit HloInstructionSequence( + absl::Span instructions) { + for (const HloInstruction* instruction : instructions) { + push_back(instruction); + } + } + + // Adds the instruction to the end of the sequence. + void push_back(const HloInstruction* instruction) { + instruction_sequence_.push_back(instruction); + id_sequence_.push_back(instruction->unique_id()); + } + + // Clears the sequence of all instructions. + void clear() { + instruction_sequence_.clear(); + id_sequence_.clear(); + } + + int64 size() const { return instruction_sequence_.size(); } + + // Returns the sequence of HLO instructions. + const std::vector& instructions() const { + return instruction_sequence_; + } + + // Returns the unique IDs of the instructions in the sequence (in order). + const std::vector& ids() const { return id_sequence_; } + + private: + // The sequence as HloInstructions. + std::vector instruction_sequence_; + + // The sequence of HLO instructions, represented by their unique IDs. The + // sequence is stored as both HloInstructions and unique IDs because the + // sequence may be referenced after transformations to the HLO graph and HLO + // pointers can be invalidated or recycled in this process (see + // HloSchedule::Update). + std::vector id_sequence_; +}; + +// A class representing a sequential schedule of instructions for an HLO +// module. A complete HLO schedule contains an instruction sequence for every +// non-fusion computation in the HLO module. +class HloSchedule { + public: + explicit HloSchedule(const HloModule* module) : module_(module) {} + + // (De)Serialize an HloSchedule to/from a HloScheduleProto. + static StatusOr CreateFromProto(const HloModule* module, + const HloScheduleProto& proto); + StatusOr ToProto() const; + + // Returns a reference to the sequence for the given computation. + const HloInstructionSequence& sequence( + const HloComputation* computation) const; + + // Returns the sequence for the given computation. An empty sequence is + // created if none exists for the computation. + HloInstructionSequence& GetOrCreateSequence( + const HloComputation* computation); + + // Sets the sequence for the given computation to the given sequence. + void set_sequence(const HloComputation* computation, + absl::Span sequence); + void set_sequence(const HloComputation* computation, + HloInstructionSequence sequence); + + // Returns a map from HloComputation unique ID to instruction sequence. The + // map contains all sequences in the schedule. + const tensorflow::gtl::FlatMap& sequences() + const { + return sequences_; + } + + // Returns true if the schedule has a sequence for the given computation. + bool is_computation_scheduled(const HloComputation* computation) const { + return sequences_.count(computation->unique_id()) == 1; + } + + // Updates the schedule such that it is (again) a valid schedule for the + // module. This is used to update a schedule after the HLO module has been + // transformed in some way. In general, the only transformations to the module + // for which a schedule can be updated is the addition or removal of + // instructions and removal of computations. Updating the schedule after new + // dependencies between existing instructions in the module is not supported + // and may result in an error status returned. + // + // Instructions in the module which also exist in the given schedule will + // remain in the same order in the updated schedule. Instructions which exist + // in the module but not in the given schedule will be placed as early as + // possible in the updated schedule. + Status Update(); + + // Verifies that the given schedule is valid for the given module. + // Specifically, the schedule contains exactly the instructions in the + // non-fusion computations in the module and every dependency in the module is + // satisfied in the schedule. + Status Verify() const; + + string ToString() const; + + bool empty() const { return sequences_.empty(); } + + const HloModule* module() const { return module_; } + + private: + // Updates the instruction sequence for the given computation. + Status UpdateComputationSchedule(const HloComputation* computation); + + const HloModule* module_; + + // A map from computation unique ID to instruction sequence. Unique IDs are + // used rather than HloComputation pointers because HLO pointers are not + // unique across HLO transformations because pointers may be recycled. + tensorflow::gtl::FlatMap sequences_; +}; + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1424569ac1f62e4b965876141f1eb40be4f15bea --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloScheduleTest : public HloTestBase {}; + +TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. + const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + const std::vector& entry_schedule = + schedule.sequence(module->entry_computation()).instructions(); + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(entry_schedule, + schedule.sequence(module->entry_computation()).instructions()); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. + const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo); + }; + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 4); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 3); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 2); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. + cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. + body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(body).size(), 7); + EXPECT_EQ(schedule.sequence(cond).size(), 4); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(body).size(), 1); + EXPECT_EQ(schedule.sequence(cond).size(), 5); +} + +TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) { + // Remove computations from a module and verify the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + HloInstruction* xla_while = + module->entry_computation()->root_instruction()->mutable_operand(0); + HloInstruction* init = xla_while->mutable_operand(0); + + // Replace the while with its init value. The conditional and body + // computations should then be dead. + TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init)); + + // DCE the dead code in the body. + HloDCE dce; + ASSERT_EQ(module->computation_count(), 3); + TF_ASSERT_OK(dce.Run(module.get()).status()); + ASSERT_EQ(module->computation_count(), 1); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 6399f6ef3c56383b9357f5f280e4a123dadca693..de7e6b53d4d2aa88e2213248370b4da82bdeadeb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -15,13 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; +using absl::StrCat; +using absl::StrJoin; HloSharding HloSharding::AssignDevice(int64 device_id) { return HloSharding(device_id); @@ -31,12 +32,9 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { CHECK_EQ(1, ShapeUtil::Rank(input_shape)); CHECK_GT(num_tiles, 1); std::vector dimensions(1, num_tiles); - Shape tile_shape = input_shape; - auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; - tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); Array assignment(dimensions); std::iota(assignment.begin(), assignment.end(), 0); - return HloSharding(tile_shape, assignment); + return HloSharding(assignment); } HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { @@ -56,9 +54,8 @@ HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { return HloSharding(flattened_list); } -HloSharding HloSharding::Tuple( - const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings) { +HloSharding HloSharding::Tuple(const Shape& tuple_shape, + absl::Span shardings) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); for (auto& sharding : shardings) { CHECK(!sharding.IsTuple()) << sharding.ToString(); @@ -74,12 +71,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, const HloSharding& sharding) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); CHECK(!sharding.IsTuple()) << sharding.ToString(); - int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + int64 leaf_count = RequiredLeaves(tuple_shape); std::vector flattened_list; - flattened_list.reserve(leaf_count); - for (int64 i = 0; i < leaf_count; ++i) { - flattened_list.push_back(sharding); - } + flattened_list.resize(leaf_count, sharding); return HloSharding(flattened_list); } @@ -95,7 +89,7 @@ string HloSharding::ToString() const { for (const HloSharding& element : tuple_elements_) { parts.push_back(element.ToString()); } - return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); + return StrCat("{", absl::StrJoin(parts, ", "), "}"); } if (replicated_) { @@ -104,9 +98,8 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", - Join(tile_assignment_.dimensions(), ","), "]", - Join(tile_assignment_, ","), "}"); + return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), + "]", StrJoin(tile_assignment_, ","), "}"); } } @@ -145,11 +138,10 @@ std::map HloSharding::UsedDevices(int64* count) const { } std::vector HloSharding::TileIndexForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); CHECK(!IsTuple()); std::vector ret_index; - tile_assignment_.Each([&](tensorflow::gtl::ArraySlice index, int64 d) { + tile_assignment_.Each([&](absl::Span index, int64 d) { if (d == device) { ret_index = {index.begin(), index.end()}; } @@ -158,39 +150,49 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { return ret_index; } -int64 HloSharding::DeviceForTileIndex( - tensorflow::gtl::ArraySlice index) const { +int64 HloSharding::DeviceForTileIndex(absl::Span index) const { CHECK(!replicated_); CHECK(!IsTuple()); if (maximal_) { return *tile_assignment_.begin(); } - CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size()); return tile_assignment_(index); } -std::vector HloSharding::TileOffsetForDevice(int64 device) const { +std::vector HloSharding::TileOffsetForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - std::vector index = TileIndexForDevice(device); if (maximal_) { - // Index will always be all zeroes if we're maximal, and tile_shape_ is not - // valid. - return index; + return std::vector(shape.dimensions_size(), 0); } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] *= tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim); } return index; } -std::vector HloSharding::TileLimitForDevice(int64 device) const { +std::vector HloSharding::TileLimitForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. + if (maximal_) { + return std::vector(shape.dimensions().begin(), + shape.dimensions().end()); + } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] = (index[i] + 1) * tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), + shape_dim); } return index; } @@ -238,16 +240,16 @@ StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { return Tuple(ShapeTree(shape, *this)); } -tensorflow::gtl::optional HloSharding::UniqueDevice() const { +absl::optional HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } - tensorflow::gtl::optional unique_device; + absl::optional unique_device; for (auto& tuple_sharding : tuple_elements_) { auto device = tuple_sharding.UniqueDevice(); if (!device || (unique_device && *device != *unique_device)) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } unique_device = device; } @@ -256,7 +258,7 @@ tensorflow::gtl::optional HloSharding::UniqueDevice() const { if (!replicated_ && maximal_) { return static_cast(*tile_assignment_.begin()); } - return tensorflow::gtl::nullopt; + return absl::nullopt; } int64 HloSharding::GetUniqueDevice() const { @@ -315,7 +317,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, Status status = Status::OK(); std::set seen_cores; tile_assignment_.Each( - [&](tensorflow::gtl::ArraySlice indices, int32 core) { + [&](absl::Span indices, int32 core) { // Don't overwrite a bad status, so we report the first error. if (status.ok()) { if (core >= num_devices) { @@ -336,11 +338,12 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return Status::OK(); } - // The tile rank must be the same as the input rank. - if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { + // The tile assignment tensor must have the same rank as the input. + if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank. sharding=", ToString(), - ", input_shape=", ShapeUtil::HumanString(shape)); + "Number of tile assignment dimensions is different to the input rank. " + "sharding=", + ToString(), ", input_shape=", ShapeUtil::HumanString(shape)); } // The correct constructor have to be used to create tile maximal shardings. @@ -350,20 +353,6 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, "sharding was intended, use HloSharding::Replicated(). If a device " "placement was intended, use HloSharding::AssignDevice()"); } - - // The tile assignment tensor must contain enough element to cover the full - // shape with tiles of the specified size. - for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { - int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i); - if (shape.dimensions(i) > total_tile_size) { - return tensorflow::errors::InvalidArgument( - StrCat("Tile assignment tensor has too few element to cover the full " - "shape. Dimension ", - i, ", shape ", shape.dimensions(i), ", total size ", - total_tile_size)); - } - } - return Status::OK(); } @@ -393,7 +382,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, proto.tile_assignment_dimensions().end())); std::copy(proto.tile_assignment_devices().begin(), proto.tile_assignment_devices().end(), tile_assignment.begin()); - return HloSharding(proto.tile_shape(), tile_assignment); + return HloSharding(tile_assignment); } OpSharding HloSharding::ToProto() const { @@ -407,7 +396,6 @@ OpSharding HloSharding::ToProto() const { return result; } - *result.mutable_tile_shape() = tile_shape_; for (int64 dim : tile_assignment_.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -424,58 +412,54 @@ OpSharding HloSharding::ToProto() const { return result; } -HloSharding HloSharding::TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform) const { - CHECK(!IsTuple()); +Shape HloSharding::TileShape(const Shape& shape) const { if (IsTileMaximal()) { - return *this; + return shape; } - CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape())); - Shape new_tile_shape; - new_tile_shape.set_element_type(tile_shape().element_type()); - for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) { - int64 dim; - if (tile_assignment().dim(i) == 1) { - dim = new_shape.dimensions(i); - } else if (transform) { - dim = transform(i, tile_shape().dimensions(i)); - } else { - dim = tile_shape().dimensions(i); - } - new_tile_shape.add_dimensions(dim); + Shape result_shape = shape; + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + (*result_shape.mutable_dimensions())[i] = + CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i)); } - TF_CHECK_OK( - LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape)); - return HloSharding::Tile(new_tile_shape, tile_assignment()); + return result_shape; } HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); - - Shape sub_shape = ShapeUtil::GetSubshape(shape, index); - ShapeTree sub_shape_tree(sub_shape, Replicate()); - sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree) - : sub_shape_tree.element(ShapeIndex({})); + int64 sharding_index = 0; + const Shape* sub_shape = &shape; + for (int64 idx : index) { + for (int64 i = 0; i < idx; ++i) { + sharding_index += + ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i})); + } + sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx}); + } + if (ShapeUtil::IsTuple(*sub_shape)) { + auto begin_it = tuple_elements_.begin() + sharding_index; + std::vector sub_shardings( + begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape)); + return HloSharding::Tuple(*sub_shape, sub_shardings); + } else { + return tuple_elements_[sharding_index]; + } } -tensorflow::gtl::optional HloSharding::ExtractSingleSharding() - const { +absl::optional HloSharding::ExtractSingleSharding() const { if (!IsTuple()) { return *this; } for (int64 i = 1; i < tuple_elements_.size(); ++i) { if (tuple_elements_[0] != tuple_elements_[i]) { - return tensorflow::gtl::optional(); + return absl::nullopt; } } return tuple_elements_.front(); } size_t HloSharding::Hash() const { - if (!tuple_) { + if (tuple_) { size_t h = 0; for (const auto& element : tuple_elements_) { h = tensorflow::Hash64Combine(h, element.Hash()); @@ -489,9 +473,6 @@ size_t HloSharding::Hash() const { for (uint32 v : tile_assignment_) { h = tensorflow::Hash64Combine(h, std::hash{}(v)); } - for (uint32 v : tile_shape_.dimensions()) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } return h; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 28575c0e75548d1a21381b37754232c5d843dfbe..9775505f8608ced3e33abe376f4922cc6a972726 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -23,12 +23,12 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -48,22 +48,10 @@ class HloSharding { // the input shape (one tile) assigned to a single device. static HloSharding AssignDevice(int64 device_id); - // Creates a new sharding which splits a shape into tiles each with shape - // `tile_shape`. Each tile is assigned to one device, which is specified by - // `tile_assignment`. Any tensor not a multiple of the tile size in any - // dimension is implicitly padded to the tile size. - // - // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like: - // 2 1 padding - // <------><-> - // +----+----+ - // | 0 | 1 | - // +----+----+ - // - // Split into two tiles, one of which is implicitly padded by one. - static HloSharding Tile(const Shape& tile_shape, - const Array& tile_assignment) { - return HloSharding(tile_shape, tile_assignment); + // Creates a new sharding which splits a shape into tiles amongst the devices + // specified by `tile_assignment`. + static HloSharding Tile(const Array& tile_assignment) { + return HloSharding(tile_assignment); } // Creates a new sharding which splits a one-dimensional input shape into @@ -78,7 +66,7 @@ class HloSharding { // shardings must match the number of leaf nodes in tuple_shape. For // empty tuples, the shardings array must have one element. static HloSharding Tuple(const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings); + absl::Span shardings); // Creates a new sharding for a tuple type, with a single input sharding // repeated on each leaf. @@ -144,25 +132,26 @@ class HloSharding { // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. // REQUIRES: !IsTuple() - int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; + int64 DeviceForTileIndex(absl::Span index) const; - // Given a device ID, returns the offset within the input space of the + // Given a device ID, returns the offset within the specified shape of the // tile that should be executed on the given core. This returns the lower // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileOffsetForDevice(int64 device) const; + std::vector TileOffsetForDevice(const Shape& shape, + int64 device) const; - // Given a device ID, returns the limit within the input space of the + // Given a device ID, returns the limit within the specified shape of the // tile that should be executed on the given core. This returns the upper // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileLimitForDevice(int64 device) const; + std::vector TileLimitForDevice(const Shape& shape, int64 device) const; // Returns the single device this op operates on. If the sharding does not // span a single device, the return value will be empty. // In order for a sharding to span a single device, every leaf sharding must // be maximal and not replicated, and the used device must match. - tensorflow::gtl::optional UniqueDevice() const; + absl::optional UniqueDevice() const; // Retrieves the unique device or fails with a CHECK. int64 GetUniqueDevice() const; @@ -193,11 +182,10 @@ class HloSharding { // be returned. If it is a tuple, and all the tuple elements are common, the // common element will be returned. Otherwise the optional will contain no // value. - tensorflow::gtl::optional ExtractSingleSharding() const; + absl::optional ExtractSingleSharding() const; bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_; } @@ -211,9 +199,6 @@ class HloSharding { } }; - // Gets the tile shape. - // REQUIRES: !IsTileMaximal() && !IsTuple() - const Shape& tile_shape() const { return tile_shape_; } // Gets the tile assignment tensor. // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } @@ -225,25 +210,15 @@ class HloSharding { return tuple_elements_; } - // Return a new sharding that can apply to the given new shape. - // If this sharding is tile-maximal, the returned sharding will be the same as - // this sharding. If this sharding is not tile-maximal, the returned - // sharding's tile size will differ: - // - Non-sharded dimensions will be adapted to be the same as `new_shape`; - // tile_dimension(i) = new_shape.dimensions(i); - // - Sharded dimensions will be kept the same unless `transform` is supplied - // in which case tile_dimension(i) = transform(i, tile_dimension(i)); - // REQUIRES: !IsTuple(). - HloSharding TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform = nullptr) const; + // Gets the tile shape. + // REQUIRES: !IsTuple() + Shape TileShape(const Shape& shape) const; private: HloSharding() : replicated_(true), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({0}) {} // device_id values: // -2: magic number to mean unassigned device, used by spatial partitioning @@ -255,15 +230,13 @@ class HloSharding { : replicated_(false), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({1}, device_id) {} - HloSharding(const Shape& tile_shape, const Array& tile_assignment) + explicit HloSharding(const Array& tile_assignment) : replicated_(false), maximal_(false), tuple_(false), - tile_shape_(tile_shape), tile_assignment_(tile_assignment) {} - HloSharding(const std::vector& tuple_shardings) + explicit HloSharding(const std::vector& tuple_shardings) : replicated_(false), maximal_(false), tuple_(true), @@ -286,11 +259,10 @@ class HloSharding { bool replicated_; bool maximal_; bool tuple_; - Shape tile_shape_; Array tile_assignment_; - // Only non-empty when tuple_ is true, but because empty tuples are allowed - // may also be empty even then. This is a flattened list of all the leaf - // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). + // Only non-empty when tuple_ is true. If a tuple is empty then one entry is + // present for the root. This is a flattened list of all the leaf shardings in + // a tuple shape, by pre-order walk (ShapeTree iterator order). std::vector tuple_elements_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 94f5a3b273b2fd7e545472c42f3863f549dd3db1..e3f4a9852ace86c20610362aa6ad3c3d9c78de30 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -23,6 +24,23 @@ namespace xla { namespace { +// AssignmentKind and kUnassignedDevice are used during tuple domain sharding +// propagation in order to distinguish among three cases: +// kUnassigned: no assignment has occurred +// kAssigned: at least an assignment has occurred +// kConflict: no assignment has occurred because of conflicting propagations, +// which occurs when multiple users of an instruction have different +// shardings. +enum class AssignmentKind { kUnassigned, kAssigned, kConflict }; + +// kUnassignedDevice can only be assigned to tuple leaf shardings to indicate +// absence of sharding information for that particular sub-sharding during +// sharding propagation. It is used to be able to express tuple shardings with +// partial information. At the end of the propagation the sharding of +// tuple-shaped instructions using kUnassignedDevice's is cleared. +// TODO(b/112883246): Centralized enum of reserved devices. +constexpr int64 kUnassignedDevice = -2; + struct PassThrough { PassThrough(HloInstruction* user, HloInstruction* operand) : user(user), operand(operand) {} @@ -117,13 +135,17 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, return Status::OK(); } -std::unique_ptr CloneShardingForDomain( - const HloSharding& sharding) { - auto single_sharding = sharding.ExtractSingleSharding(); +// For tuple shardings if every element have the same sharsing then we want to +// treat them as single element sharsings to insert less domain separation as a +// domain can prevent some optimizations and we want to minimize that from +// happening. +std::shared_ptr CloneShardingForDomain( + std::shared_ptr sharding) { + auto single_sharding = sharding->ExtractSingleSharding(); if (!single_sharding) { - return MakeUnique(sharding); + return sharding; } - return MakeUnique(*single_sharding); + return std::make_shared(*single_sharding); } Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, @@ -142,102 +164,174 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, return Status::OK(); } -// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. -// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() -// sharding will be returned. -ShapeTree GetTupleSharding(HloInstruction* tuple) { - if (tuple->has_sharding()) { - return tuple->sharding().GetAsShapeTree(tuple->shape()); +// Return the ShapeTree of the user argument. The user argument +// is assumed to be a user of the instruction argument. +// If user is a tuple instruction, return the tuple subsharding corresponding to +// the operand matching the instruction argument, because that is the +// subsharding corresponding to instruction. +ShapeTree GetShardingTreeFromUser( + const HloInstruction& instruction, const HloInstruction& user) { + if (user.opcode() == HloOpcode::kTuple) { + return user.sharding() + .GetSubSharding(user.shape(), {user.operand_index(&instruction)}) + .GetAsShapeTree(instruction.shape()); + } + return user.sharding().GetAsShapeTree(user.shape()); +} + +// Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice) +// then no assignment is made. Therefore kUnassignedDevice is never propagated. +// kConflict is returned if lhs is already assigned and rhs is assigned to a +// different device. +StatusOr AssignLeafSharding(HloSharding* lhs, + const HloSharding& rhs) { + TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple()); + if (rhs.UsesDevice(kUnassignedDevice)) { + return AssignmentKind::kUnassigned; } - return ShapeTree(tuple->shape(), HloSharding::Replicate()); + if (lhs->UsesDevice(kUnassignedDevice)) { + *lhs = rhs; + return AssignmentKind::kAssigned; + } + return lhs->UniqueDevice() != rhs.UniqueDevice() + ? AssignmentKind::kConflict + : AssignmentKind::kUnassigned; } -// Retrieves the sharding of operand, asked from a user instruction which is -// within domain. If operand is a kDomain, it means that sharding argument is -// the operand sharding, otherwise the operand's own sharding will be returned. -const HloSharding* GetOperandSharding(const HloInstruction* operand, +// Assigns the whole rhs tree to lhs_tree, starting at lhs_it. +// In case of conflicting assignment AssignmentKind::kConflict is returned. In +// this case lhs_tree is partially assigned, up to the conflicting leaf. It is +// up to the caller to discard the partial assignment in case of conflict. +StatusOr AssignTreeSharding( + ShapeTree* lhs_tree, ShapeTree::iterator lhs_it, + const ShapeTree& rhs_tree) { + AssignmentKind assigned = AssignmentKind::kUnassigned; + auto rhs_it = rhs_tree.begin(); + for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end(); + ++lhs_it, ++rhs_it) { + // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it) + if (rhs_tree.IsLeaf(rhs_it->first)) { + TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first)); + TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned, + AssignLeafSharding(&lhs_it->second, rhs_it->second)); + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we return conflict to the caller. At this point + // partial assignments to lhs_tree may have been made already. It is up + // to the caller to discard the partial assignment in case of conflict. + return AssignmentKind::kConflict; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } + } + TF_RET_CHECK(rhs_it == rhs_tree.end()); + return assigned; +} + +StatusOr ApplyShardingFromUsers(HloInstruction* instruction, const DomainMetadata::Domain& domain, - const HloSharding& sharding) { - DCHECK_EQ(domain.reach_set.count(const_cast(operand)), 1); - // Here the user of operand is within the domain instruction set, and since it - // is user of operand, we need to look into the enter_domains set. If this is - // not a kDomain within the user domains set, then return the operand - // sharding, if any. - if (operand->opcode() != HloOpcode::kDomain || - domain.enter_domains.count(const_cast(operand)) == 0) { - return operand->has_sharding() ? &operand->sharding() : nullptr; + const HloSharding& domain_sharding) { + if (instruction->users().empty()) { + // No sharding from users, use domain_sharding, after checking + // compatibility. + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) && + ShapeUtil::GetLeafCount(instruction->shape()) == + domain_sharding.tuple_elements().size()); + instruction->set_sharding(domain_sharding); + return true; + } + AssignmentKind assigned = AssignmentKind::kUnassigned; + // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple + // subshardings can result in a final sharding assignment containing + // kUnassignedDevice leaves, in case some tuple indexes are not used, or are + // used by users that don't have a sharding. + // Non-tuple shardings are either assigned to a real sharding, or are not + // assigned at all. As such they will never get assigned to kUnassignedDevice. + // In any case, kUnassignedDevice is never propagated, from the implementation + // of AssignLeafSharding. + ShapeTree sharding_tree( + instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(const_cast(user)) > 0) { + // If a user is a domain and it is registered in the domain exits, then + // the instruction sharding is taken directly from the domain, and no + // further users need to be visited. + instruction->set_sharding(domain_sharding); + return true; + } + if (!user->has_sharding()) { + continue; + } + AssignmentKind sub_assigned = AssignmentKind::kUnassigned; + ShapeTree user_sharding_tree = + GetShardingTreeFromUser(*instruction, *user); + if (ShapeUtil::IsTuple(instruction->shape())) { + // For tuple-shaped instructions collect individual tuple subshardings + // from the uses, and then combine them into the tuple sharding. + // If the user is a GTE its sharding concerns only the subtree of + // sharding_tree at index user->tuple_index, otherwise the whole + // sharding_tree is affected. + ShapeTree::iterator sharding_tree_begin = + user->opcode() == HloOpcode::kGetTupleElement + ? sharding_tree.find({user->tuple_index()}) + : sharding_tree.begin(); + TF_ASSIGN_OR_RETURN( + sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin, + user_sharding_tree)); + } else { + // Non-tuple shape: assign common users sharding. + TF_RET_CHECK(user_sharding_tree.leaf_count() == 1) + << "Expected non-tuple user sharding"; + TF_ASSIGN_OR_RETURN( + sub_assigned, + AssignTreeSharding(&sharding_tree, sharding_tree.begin(), + user_sharding_tree)); + } + + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we don't assign any sharding. + return false; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } + + if (assigned == AssignmentKind::kAssigned) { + if (ShapeUtil::IsTuple(instruction->shape())) { + instruction->set_sharding(HloSharding::Tuple(sharding_tree)); + } else { + TF_RET_CHECK(sharding_tree.leaf_count() == 1); + instruction->set_sharding(sharding_tree.leaf_begin()->second); + } + return true; } - // At this point operand is a kDomain of the currently processed domain, so we - // can refer to sharding as the domain sharding. - return &sharding; + return false; } // Tries to propagate the sharding information into the instructions that are -// part of the domain, in a post order manner (operand propagate to user). +// part of the domain, in a reverse post order manner (users propoagate to +// instruction). StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, - const HloSharding& sharding) { + const HloSharding& domain_sharding) { int64 assigned = 0; - for (HloInstruction* instruction : domain.instructions) { + // domain.instructions are ordered in a post-order manner. As we do + // user->operand propagation we process instructions in reverse order. In so + // doing we are guaranteed to process all users before their operands. + for (auto it = domain.instructions.rbegin(); it != domain.instructions.rend(); + ++it) { + HloInstruction* instruction = *it; if (instruction->has_sharding()) { continue; } - if (instruction->opcode() == HloOpcode::kGetTupleElement) { - HloInstruction* tuple = instruction->mutable_operand(0); - const HloSharding* tuple_sharding = - GetOperandSharding(tuple, domain, sharding); - if (tuple_sharding != nullptr) { - if (tuple_sharding->IsTuple()) { - HloSharding sub_sharding = tuple_sharding->GetSubSharding( - tuple->shape(), {instruction->tuple_index()}); - VLOG(4) << " " << instruction->name() << " to sharding " - << sub_sharding; - instruction->set_sharding(sub_sharding); - } else { - SetSingleSharding(instruction, *tuple_sharding); - } - ++assigned; - } - } else if (instruction->opcode() == HloOpcode::kTuple) { - int64 tuple_assigned = 0; - ShapeTree shape_tree = GetTupleSharding(instruction); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const HloSharding* operand_sharding = - GetOperandSharding(instruction->operand(i), domain, sharding); - if (operand_sharding != nullptr && - shape_tree.element({i}) != *operand_sharding) { - *shape_tree.mutable_element({i}) = *operand_sharding; - ++tuple_assigned; - } - } - if (tuple_assigned > 0) { - HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); - VLOG(4) << " " << instruction->name() << " to sharding " - << tuple_sharding; - instruction->set_sharding(tuple_sharding); - ++assigned; - } - } else { - // If all the operand of the given instruction has the same single device - // assignment, assign that device to this instruction as well. - const HloSharding* common_sharding = nullptr; - for (const HloInstruction* operand : instruction->operands()) { - const HloSharding* operand_sharding = - GetOperandSharding(operand, domain, sharding); - if (operand_sharding != nullptr) { - if (common_sharding != nullptr && - *common_sharding != *operand_sharding) { - common_sharding = nullptr; - break; - } - common_sharding = operand_sharding; - } - } - if (common_sharding != nullptr) { - VLOG(4) << " " << instruction->name() << " to sharding " - << *common_sharding; - instruction->set_sharding(*common_sharding); - ++assigned; - } + // Take the sharding from the users. + TF_ASSIGN_OR_RETURN( + bool instruction_assigned, + ApplyShardingFromUsers(instruction, domain, domain_sharding)); + if (instruction_assigned) { + ++assigned; + VLOG(4) << " " << instruction->name() << " to sharding " + << instruction->sharding(); } } return assigned; @@ -255,83 +349,40 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, return ApplyDomainSingleSharding(domain, *single_sharding); } VLOG(1) << "Assigning non-trivial sharding " << sharding; - for (;;) { - TF_ASSIGN_OR_RETURN(int64 assigned, - ApplyDomainShardingPass(domain, sharding)); - if (assigned == 0) { - break; - } - } + TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status()); + int64 unassigned = 0; for (HloInstruction* instruction : domain.instructions) { if (!instruction->has_sharding()) { LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); ++unassigned; + } else { + // Un-set sharding of tuples whose sub-sgardings are assigned to + // kUnassignedDevice. Indeed in case of doubt it is better to leave the + // entire tuple unassigned, and let the device placer decide for it. + if (instruction->sharding().UsesDevice(kUnassignedDevice)) { + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())) + << "Only tuples can have kUnassignedDevice sub shardings"; + instruction->clear_sharding(); + } } } // Should we error out if unassigned > 0? return Status::OK(); } -// Creates a kDomain instruction to be placed between instruction and operand. -// The kDomain instruction will be created only if the sharding differ between -// the instruction and the operand. -std::unique_ptr CreateDomain(HloInstruction* instruction, - HloInstruction* operand) { - const HloSharding* instruction_sharding = - instruction->has_sharding() ? &instruction->sharding() : nullptr; - const HloSharding* operand_sharding = - operand->has_sharding() ? &operand->sharding() : nullptr; - // No need for domain if they both have no sharding. - if (instruction_sharding == nullptr && operand_sharding == nullptr) { - return nullptr; - } - // No need for domain if they match. - if (instruction_sharding != nullptr && operand_sharding != nullptr && - ShardingMatches(*instruction_sharding, *operand_sharding)) { - return nullptr; - } - std::unique_ptr real_instruction_sharding; - std::unique_ptr real_operand_sharding; - if (instruction_sharding != nullptr) { - real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); - } - if (operand_sharding != nullptr) { - real_operand_sharding = CloneShardingForDomain(*operand_sharding); - } - VLOG(3) << "Creating domain:"; - VLOG(3) << " Instruction: " << instruction->name(); - VLOG(3) << " Operand: " << operand->name(); - VLOG(3) << " User side sharding: " - << (real_instruction_sharding != nullptr - ? real_instruction_sharding->ToString() - : "None"); - VLOG(3) << " Operand side sharding: " - << (real_operand_sharding != nullptr - ? real_operand_sharding->ToString() - : "None"); - - std::unique_ptr operand_side_metadata = - MakeUnique(std::move(real_operand_sharding)); - std::unique_ptr user_side_metadata = - MakeUnique(std::move(real_instruction_sharding)); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); -} - -StatusOr> ExtractOriginalCommonSharding( - tensorflow::gtl::ArraySlice instructions) { +StatusOr> ExtractOriginalCommonSharding( + absl::Span instructions) { // If we are here, all the instructions being passed had the same sharding // (or no sharding), by the means of the ShardingMatches() API. // As such, no kDomain was inserted, and here we are asked to extract the // original common sharding. // All the instructions passed to this API are part of the same computation. - const HloSharding* sharding = nullptr; + std::shared_ptr sharding; for (HloInstruction* instruction : instructions) { if (instruction->has_sharding()) { if (sharding == nullptr) { - sharding = &instruction->sharding(); + sharding = instruction->sharding_ptr(); } else { TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) << "Sharding " << *sharding << " does not match the one in " @@ -340,10 +391,10 @@ StatusOr> ExtractOriginalCommonSharding( } } if (sharding == nullptr) { - return std::unique_ptr(); + return std::shared_ptr(); } VLOG(4) << "Extracted sharding is " << *sharding; - return CloneShardingForDomain(*sharding); + return CloneShardingForDomain(sharding); } } // namespace @@ -351,9 +402,9 @@ StatusOr> ExtractOriginalCommonSharding( std::unique_ptr ShardingMetadata::Clone() const { std::unique_ptr sharding; if (sharding_ != nullptr) { - sharding = MakeUnique(*sharding_); + sharding = absl::make_unique(*sharding_); } - return MakeUnique(std::move(sharding)); + return absl::make_unique(std::move(sharding)); } bool ShardingMetadata::Matches(const DomainMetadata& other) const { @@ -371,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const { : false; } +size_t ShardingMetadata::Hash() const { + if (sharding_ != nullptr) { + return sharding_->Hash(); + } + return static_cast(0x297814aaad196e6dULL); +} + string ShardingMetadata::ToString() const { return sharding_ != nullptr ? sharding_->ToString() : "{}"; } @@ -397,7 +455,7 @@ Status ShardingMetadata::NormalizeShardingDomain( TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding)); } } else { - TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + TF_ASSIGN_OR_RETURN(std::shared_ptr sharding, ExtractOriginalCommonSharding(domain.instructions)); if (sharding != nullptr) { VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString(); @@ -409,9 +467,75 @@ Status ShardingMetadata::NormalizeShardingDomain( return Status::OK(); } -std::unique_ptr CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand) { - return CreateDomain(instruction, operand); +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { + auto instruction_sharding = instruction->sharding_ptr(); + auto root_sharding = root->sharding_ptr(); + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && root_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && root_sharding != nullptr && + ShardingMatches(*instruction_sharding, *root_sharding)) { + return nullptr; + } + + if (instruction_sharding != nullptr) { + instruction_sharding = CloneShardingForDomain(instruction_sharding); + } + if (root_sharding != nullptr) { + root_sharding = CloneShardingForDomain(root_sharding); + } + + auto it = domain_cse_map_.find({operand, instruction_sharding}); + if (it != domain_cse_map_.end()) { + return it->second; + } + + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (instruction_sharding != nullptr ? instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (root_sharding != nullptr ? root_sharding->ToString() : "None"); + + HloInstruction* domain = + operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, + absl::make_unique(root_sharding), + absl::make_unique(instruction_sharding))); + domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding}, + domain); + return domain; +} + +bool ShardingDomainCreator::DomainCseMapKey::operator==( + const ShardingDomainCreator::DomainCseMapKey& other) const { + if (instruction != other.instruction) { + return false; + } + if (sharding == nullptr && other.sharding == nullptr) { + return true; + } + if (sharding == nullptr || other.sharding == nullptr) { + return false; + } + return *sharding == *other.sharding; +} + +size_t ShardingDomainCreator::DomainCseMapHasher::operator()( + const ShardingDomainCreator::DomainCseMapKey& key) const { + return tensorflow::Hash64Combine( + std::hash{}(key.instruction), + key.sharding ? key.sharding->Hash() + : static_cast(0x297814aaad196e6dULL)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index 5e01fc0e22ae8f3421c2cb5790adf44b1200a804..e3ae82a070643895f2ecac0e64073a88b592f7c1 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -16,31 +16,33 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { // A DomainMetadata implementation that internally wraps a sharding attribute. class ShardingMetadata : public DomainMetadata { public: - explicit ShardingMetadata(std::unique_ptr sharding) + explicit ShardingMetadata(std::shared_ptr sharding) : sharding_(std::move(sharding)) {} std::unique_ptr Clone() const override; - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return KindName(); } bool Matches(const DomainMetadata& other) const override; + size_t Hash() const override; + string ToString() const override; const HloSharding* sharding() const { return sharding_.get(); } - static tensorflow::StringPiece KindName() { return "sharding"; } + static absl::string_view KindName() { return "sharding"; } static StatusOr ToShardingMetadata( const DomainMetadata* metadata); @@ -55,15 +57,33 @@ class ShardingMetadata : public DomainMetadata { const DomainMetadata* metadata); private: - std::unique_ptr sharding_; + std::shared_ptr sharding_; }; -// Given an HLO graph edge between instruction and one of its operands, creates -// a ShardingMetadata based kDomain instruction if the sharding between -// instruction and operand changes. Returns nullptr if there is no need for a -// domain separation. -std::unique_ptr CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand); +// If the sharding between root and instruction changes then returns a +// ShardingMetadata based kDomain instruction what can be used to separate +// operand and instruction. +// Returns nullptr if there is no need for a domain separation. +class ShardingDomainCreator { + public: + HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root, + HloInstruction* operand); + + private: + // Map from instruction and user sharding to domain users to CSE identical + // domains. + struct DomainCseMapKey { + const HloInstruction* instruction; + std::shared_ptr sharding; + + bool operator==(const DomainCseMapKey& other) const; + }; + struct DomainCseMapHasher { + size_t operator()(const DomainCseMapKey& key) const; + }; + std::unordered_map + domain_cse_map_; +}; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index aebda562d38a2e46be1ba1572a92213afeab40e5..80634677e78e4a35dcb9bf7de018a88122c3c030 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -29,8 +29,8 @@ limitations under the License. namespace xla { namespace { -Array MakeArray(tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice contents) { +Array MakeArray(absl::Span dimensions, + absl::Span contents) { Array a(dimensions); std::copy(contents.begin(), contents.end(), a.begin()); return a; @@ -39,7 +39,6 @@ Array MakeArray(tensorflow::gtl::ArraySlice dimensions, class HloShardingTest : public HloTestBase {}; TEST_F(HloShardingTest, Replicate) { - Shape tile_shape = ShapeUtil::MakeShape(U32, {4}); HloSharding sharding = HloSharding::Replicate(); EXPECT_TRUE(sharding.IsReplicated()); EXPECT_TRUE(sharding.IsTileMaximal()); @@ -79,37 +78,22 @@ TEST_F(HloShardingTest, DevicePlacement) { TEST_F(HloShardingTest, Tile) { { // Test should fail because of a duplicate tile assignment. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}), /*num_devices=*/4)); } { // Test should fail because of more devices used then `num_device`. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}), /*num_devices=*/2)); } - { - // Test should fail because the total tiled size in dimension 0 is 4 but we - // have 6 elements along that dimensions. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); - EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}), - /*num_devices=*/4)); - } - { // Test should pass. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); + Shape shape = ShapeUtil::MakeShape(U32, {4, 5}); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}), /*num_devices=*/5)); @@ -118,15 +102,26 @@ TEST_F(HloShardingTest, Tile) { EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0})); EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1})); - EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector{0, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector{0, 3})); - EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector{2, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector{2, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0), + (std::vector{0, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3), + (std::vector{0, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2), + (std::vector{2, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1), + (std::vector{2, 3})); EXPECT_FALSE(sharding.HasUniqueDevice()); } } +// Tests that empty tuple is supported. +TEST_F(HloShardingTest, EmptySingleTuple) { + HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), + HloSharding::AssignDevice(0)); + EXPECT_TRUE(sharding.ExtractSingleSharding()); +} + TEST_F(HloShardingTest, NestedTuple) { // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ @@ -135,8 +130,7 @@ TEST_F(HloShardingTest, NestedTuple) { ShapeUtil::MakeShape(F32, {4, 6}), }); - HloSharding tiled_sharding = HloSharding::Tile( - ShapeUtil::MakeShape(F32, {4, 3}), Array({{0, 1}})); + HloSharding tiled_sharding = HloSharding::Tile(Array({{0, 1}})); OpSharding proto; proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); @@ -187,32 +181,11 @@ TEST_F(HloShardingTest, Hash) { } { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding1 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding2 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); } - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); - EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); - } - - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 1, 2})); - EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); - } - HloSharding default_sharding = HloSharding::Replicate(); { ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), @@ -259,19 +232,6 @@ TEST_F(HloShardingTest, Hash) { } } -TEST_F(HloShardingTest, TransformShardedTileShapeTest) { - HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); - HloSharding result = sharding.TransformShardedTileShape( - ShapeUtil::MakeShape(F32, {13, 15, 17, 19}), - [](int dim, int value) { return dim * 111; }); - HloSharding expected = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}), - Array4D({{{{0, 1}, {2, 3}}}})); - EXPECT_EQ(result, expected); -} - TEST_F(HloShardingTest, ToStringReplicatedTest) { HloSharding sharding = HloSharding::Replicate(); EXPECT_EQ(sharding.ToString(), "{replicated}"); @@ -284,9 +244,8 @@ TEST_F(HloShardingTest, ToStringAssignDeviceTest) { TEST_F(HloShardingTest, ToStringTiledTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), - Array3D({{{2, 3}}, {{5, 7}}})); - EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); + HloSharding::Tile(Array3D({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{devices=[2,1,2]2,3,5,7}"); } TEST_F(HloShardingTest, ToStringTupleTest) { @@ -294,21 +253,18 @@ TEST_F(HloShardingTest, ToStringTupleTest) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), ShapeUtil::MakeShape(U32, {7, 25}), ShapeUtil::MakeShape(S32, {9, 11})}), - {HloSharding::Replicate(), - HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), - Array2D({{3, 5}})), + {HloSharding::Replicate(), HloSharding::Tile(Array2D({{3, 5}})), HloSharding::AssignDevice(3)}); EXPECT_EQ(sharding.ToString(), - "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); + "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}"); } TEST_F(HloShardingTest, OstreamTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); + HloSharding::Tile(Array4D({{{{0, 1}, {2, 3}}}})); std::ostringstream oss; oss << sharding; - EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); + EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}"); } TEST_F(HloShardingTest, ParseHloString) { @@ -319,8 +275,7 @@ TEST_F(HloShardingTest, ParseHloString) { }; check(HloSharding::Replicate()); check(HloSharding::AssignDevice(2)); - check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}}))); + check(HloSharding::Tile(Array4D({{{{0}, {1}}}}))); // Empty tuple. One sharding is required for empty tuples, as we need to be // able to assign sharding to them, even though they have no leaves. check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), @@ -332,8 +287,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})}); check(HloSharding::Tuple( - tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + tuple_shape, {HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)})); } { @@ -343,8 +297,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})})}); std::vector leaf_shardings = { - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)}; ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); // Assign leaf_shardings to sharding_tree leaves. diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h index 2ef38821af632180714911c0ff22731fd559b915..fa34bddde1a47b520f7f96361d155e4017e44e60 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h @@ -22,9 +22,9 @@ namespace xla { // Unify subcomputations of a `HloModule`: if any computations are equal, choose // one arbitrarily to use and delete the others. -class HloSubcomputationUnification : public HloPassInterface { +class HloSubcomputationUnification : public HloModulePass { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "subcomputation-unification"; } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index b78bfa0cdf4db605576fa11e18ce6c654c6a0b6d..487653344976a10e18ba667085525ba1ecbb8612 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -21,28 +23,25 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -using ::tensorflow::GraphDef; -using ::tensorflow::NodeDef; -using ::tensorflow::TensorShapeProto; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; -using ::tensorflow::str_util::Join; namespace xla { namespace hlo_graph_dumper { namespace { +using absl::StrAppend; +using absl::StrCat; +using tensorflow::GraphDef; +using tensorflow::NodeDef; +using tensorflow::TensorShapeProto; + string GetOpDefName(const HloInstruction* instruction) { string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); - tensorflow::str_util::TitlecaseString(&name, "-"); + tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); if (instruction->opcode() == HloOpcode::kFusion) { string fusion_name = ToString(instruction->fusion_kind()); - StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1)); + StrAppend(&name, absl::string_view(fusion_name).substr(1)); } return name; } @@ -166,7 +165,9 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); } else { layout_string = StrCat( - "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}"); + "{", + absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","), + "}"); } attrs["layout"].set_s(layout_string); } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1..6fd734a2b9e6c8c9fca76a944ca3df4c3b8a212f 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloTestBase { +class HloTfGraphBuilderTest : public HloVerifiedTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h index 533429608bc2e13626a3e746fbe465398e1f4bb4..4458c251dee4af365e39027dd4289925c8890efd 100644 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -44,7 +44,6 @@ enum class TokKind { kRparen, // ( ) kArrow, // -> - kComment, // /*xxx*/ // Keywords kw_HloModule, diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 7fd99fc93050b386c5ad24e6dcd2fea1bf652c3f..85494877023fa3812973e993a349a7559706ab5d 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,8 +18,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -30,16 +32,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; const Shape& HloPosition::shape() const { return ShapeUtil::GetSubshape(instruction->shape(), index); @@ -132,6 +131,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, CHECK_LE(operand_number, 2); return operand_number == 0 || index.empty(); + case HloOpcode::kDomain: case HloOpcode::kTuple: // These instructions always pass through their operands transparently. return false; @@ -150,7 +150,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, } // namespace void HloValue::SetPositionsAndComputeUses( - tensorflow::gtl::ArraySlice positions) { + absl::Span positions) { CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; // The positions must be unique and should not contain the defining position @@ -216,14 +216,14 @@ void HloValueSet::SortAndUniquifyValues() { } string HloValueSet::ToString() const { - return StrCat("HloValueSet: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return StrCat( + "HloValueSet: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } -bool HloValueSet::AssignUnionOf( - tensorflow::gtl::ArraySlice inputs) { +bool HloValueSet::AssignUnionOf(absl::Span inputs) { HloValueSet union_set; for (const HloValueSet* input : inputs) { for (const HloValue* value : input->values()) { @@ -254,7 +254,7 @@ std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { } bool InstructionValueSet::AssignUnionOf( - tensorflow::gtl::ArraySlice inputs) { + absl::Span inputs) { CHECK_GT(inputs.size(), 0); for (int i = 1; i < inputs.size(); ++i) { DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index a1151f65e07dffdcd52f645f61dcc9b4f26459c0..b6670d409b92e8be42f5cdb40fba8d662ae83958 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -20,13 +20,13 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -108,8 +108,7 @@ class HloValue : public BufferValue { // Sets the positions in the module at which the HloValue appears. Updates // uses. Should be called once and only once. The defining position should not // be included in 'positions' as this is set at construction time. - void SetPositionsAndComputeUses( - tensorflow::gtl::ArraySlice positions); + void SetPositionsAndComputeUses(absl::Span positions); // Returns whether this value is a phi value. bool is_phi() const { return is_phi_; } @@ -186,14 +185,14 @@ class HloValueSet { public: HloValueSet() = default; - explicit HloValueSet(tensorflow::gtl::ArraySlice values) + explicit HloValueSet(absl::Span values) : values_(values.begin(), values.end()) { SortAndUniquifyValues(); } // Sets this value set to the union of the given value sets. Returns whether // this value set changed. - bool AssignUnionOf(tensorflow::gtl::ArraySlice inputs); + bool AssignUnionOf(absl::Span inputs); // Return the vector of HloValues in the set. Values in the vector are unique // and stably sorted by value id. @@ -247,8 +246,7 @@ class InstructionValueSet : public ShapeTree { // Sets this value set to the union of the given value sets. Returns whether // this value set changed. - bool AssignUnionOf( - tensorflow::gtl::ArraySlice inputs); + bool AssignUnionOf(absl::Span inputs); string ToString() const; }; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 1a8c206aaf35c1391923f047469f332b44d82e67..6eb66589048c1a8d6ccfed73c0f7e32f5fe6e568 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,11 +15,13 @@ limitations under the License. #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -84,7 +86,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), convolution->convolution_dimension_numbers())); + convolution->feature_group_count(), convolution->window(), + convolution->convolution_dimension_numbers())); return CheckShape(convolution, expected); } @@ -105,6 +108,20 @@ Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { ShapeInference::InferCrossReplicaSumShape(operand_shapes)); } +Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(hlo, + ShapeInference::InferAllToAllTupleShape(operand_shapes)); +} + +Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( + hlo->operand(0)->shape())); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -112,46 +129,35 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -namespace { - -Status CheckIsTokenOperand(const HloInstruction* instruction, - int64 operand_no) { +Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { const HloInstruction* token = instruction->operand(operand_no); if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { return InternalError( - "Expected operand %lld to be token-shaped, actual shape is " + "Expected operand %d to be token-shaped, actual shape is " "%s:\n%s", - operand_no, ShapeUtil::HumanString(token->shape()).c_str(), - instruction->ToString().c_str()); + operand_no, StringifyShape(token->shape()), instruction->ToString()); } return Status::OK(); } -Status CheckOperandAndParameter(const HloInstruction* instruction, - int64 operand_number, - const HloComputation* computation, - int64 parameter_number) { +Status ShapeVerifier::CheckOperandAndParameter( + const HloInstruction* instruction, int64 operand_number, + const HloComputation* computation, int64 parameter_number) { const HloInstruction* operand = instruction->operand(operand_number); const HloInstruction* parameter = computation->parameter_instruction(parameter_number); - if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) { + if (!ShapesSame(operand->shape(), parameter->shape())) { return InternalError("Operand %s shape does not match parameter's %s in %s", - operand->ToString().c_str(), - parameter->ToString().c_str(), - instruction->ToString().c_str()); + operand->ToString(), parameter->ToString(), + instruction->ToString()); } return Status::OK(); } -} // namespace - Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast(instruction); - // Infeed has an optional single token operand. - // TODO(b/80000000): Update when token is not optional. - if (infeed->operand_count() == 1) { - TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); - } + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); // The output of infeed is a tuple containing the data value and a token. return CheckShape(infeed, @@ -161,31 +167,81 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { HloOutfeedInstruction* outfeed = Cast(instruction); - // Outfeed has an optional token operand (operand 1). - // TODO(b/80000000): Update when token is not optional. - if (outfeed->operand_count() == 2) { - TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); - } + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); // Outfeed has a separate shape field for the value which is outfed to the // host. The shape of the instruction itself is always a token. - if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), - outfeed->operand(0)->shape())) { + if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed shape to be compatible with operand's shape %s, " + "Expected outfeed shape to be equal to operand's shape %s, " "actual shape is %s:\n%s", - ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), - outfeed->ToString().c_str()); + StringifyShape(outfeed->operand(0)->shape()), + StringifyShape(outfeed->outfeed_shape()), outfeed->ToString()); } return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); } -Status ShapeVerifier::HandleHostCompute(HloInstruction*) { - return Status::OK(); +bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, + const Shape& shape_1, + const Shape& result_shape) { + return ShapeUtil::SameElementType(shape_0, shape_1) && + (ShapeUtil::SameElementType(shape_0, result_shape) || + (allow_mixed_precision_ && + ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0, + result_shape))); } -Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleRng(HloInstruction* instruction) { + if (instruction->operand_count() != 2) { + return InternalError("Expected two operands for Rng instruction: %s", + instruction->ToString()); + } + + const Shape& shape_0 = instruction->operand(0)->shape(); + const Shape& shape_1 = instruction->operand(1)->shape(); + if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) { + return InternalError( + "Expected scalar types for the two operands of Rng instruction: %s", + instruction->ToString()); + } + + if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) { + return InternalError( + "Expected compatible element types for the result and the two operands" + " of Rng instruction: %s", + instruction->ToString()); + } + + PrimitiveType element_type = shape_0.element_type(); + switch (instruction->random_distribution()) { + case RNG_UNIFORM: + if (!primitive_util::IsFloatingPointType(element_type) && + !primitive_util::IsIntegralType(element_type) && + element_type != PRED) { + return InternalError( + "Element type not supported." + " Expected element to be of floating point type, integral type or" + " predicate type for RngUniform: %s", + instruction->ToString()); + } + break; + + case RNG_NORMAL: + if (!primitive_util::IsFloatingPointType(element_type)) { + return InternalError( + "Element type not supported." + " Expected element to be FloatingPointType for RngNormal: %s", + instruction->ToString()); + } + break; + default: + return InternalError( + "Invalid Rng distribution %s", + RandomDistribution_Name(instruction->random_distribution())); + } + + return Status::OK(); +} Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { return CheckShape( @@ -200,8 +256,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { return InternalError( "Expected sort to have to have the same dimensions for the keys and " "the values. Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + StringifyShape(sort->operand(0)->shape()), + StringifyShape(sort->operand(1)->shape())); } return CheckVariadicShape(sort); } @@ -210,10 +266,18 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { return CheckShape(constant, constant->literal().shape()); } -Status ShapeVerifier::HandleIota(HloInstruction* iota) { - return ShapeUtil::Rank(iota->shape()) == 1 - ? Status::OK() - : InternalError("Iota only supports arrays of rank 1."); +Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + auto* iota = Cast(instruction); + const int64 rank = ShapeUtil::Rank(iota->shape()); + if (rank == 0) { + return InternalError("Iota does not support scalars."); + } + int64 iota_dimension = iota->iota_dimension(); + if (iota_dimension >= rank) { + return InternalError( + "The iota dimension cannot go beyond the operation rank."); + } + return Status::OK(); } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { @@ -224,14 +288,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { - if (!ShapeUtil::IsArray(reduce->shape())) { - return InvalidArgument("Variadic reduce is not supported."); + std::vector operand_shapes; + for (const HloInstruction* operand : reduce->operands()) { + operand_shapes.push_back(&operand->shape()); } - return CheckShape( - reduce, - ShapeInference::InferReduceShape( - {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()}, - reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); + return CheckShape(reduce, ShapeInference::InferReduceShape( + operand_shapes, reduce->dimensions(), + reduce->to_apply()->ComputeProgramShape())); } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { @@ -275,7 +338,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { + for (HloInstruction* fused_param : fusion->fused_parameters()) { + int64 param_no = fused_param->parameter_number(); + if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { + return InternalError( + "Shape mismatch between parameter number %d and its operand in " + "%s.", + param_no, fusion->ToString().c_str()); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleCall(HloInstruction* call) { for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { @@ -357,12 +431,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); const Shape& conditional_shape = xla_while->while_condition()->root_instruction()->shape(); - if (!ShapeUtil::Compatible(conditional_shape, - ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", - ShapeUtil::HumanString(conditional_shape).c_str()); + StringifyShape(conditional_shape)); } // The shape of kWhile should match the shape of the body computation it // calls. @@ -454,9 +527,9 @@ namespace { // inputs. Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { switch (instruction->opcode()) { - // White list the following opcodes for mixed-precision check, because they - // involve data pass through or grouping via tuples, where the precisions - // of buffers can be different. + // White list the following opcodes for mixed-precision check, because + // they involve data pass through or grouping via tuples, where the + // precisions of buffers can be different. case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConstant: @@ -493,7 +566,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", - instruction->ToString().c_str()); + instruction->ToString()); } return Status::OK(); })); @@ -510,7 +583,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather, ShapeInference::InferGatherShape( gather->operand(0)->shape(), gather->operand(1)->shape(), - gather->gather_dimension_numbers(), gather->gather_window_bounds())); + gather->gather_dimension_numbers(), gather->gather_slice_sizes())); } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { @@ -540,53 +613,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } // Check if the output shape matches the expected shape. - bool compatible; + // // We treat BF16 and F32 as compatible types if mixed precision is allowed, // but only when the instruction defines the BF16/F32 buffer. - switch (instruction->opcode()) { - case HloOpcode::kTupleSelect: - // TupleSelect only defines the top-level buffer, which in this case is - // the tuple, so we cannot allow mixed precision. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - case HloOpcode::kGetTupleElement: - case HloOpcode::kTuple: - // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed - // precision is disallowed. - case HloOpcode::kConstant: - case HloOpcode::kBitcast: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kConvert: - case HloOpcode::kCustomCall: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kParameter: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kWhile: - // The above opcodes should match the expected shapes exactly. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - default: - if (allow_mixed_precision_) { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision( - instruction->shape(), inferred_shape); - } else { - compatible = - ShapeUtil::Compatible(instruction->shape(), inferred_shape); - } - } - if (!compatible) { + bool equal = [&] { + switch (instruction->opcode()) { + // The opcodes below can't have implicit layout conversions, nor can they + // implicitly transform f32 -> bf16. Fundamentally these are either + // reinterpreting existing data (e.g. kBitcast) or shuffling data around + // without modifying it (e.g. kGetTupleElement, kTupleSelect). + case HloOpcode::kBitcast: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return ShapesSame(instruction->shape(), inferred_shape); + + // We allow arbitrary layout and f32->bf16 transformations on all other + // instructions, although this may be made more strict pending discussion + // in b/112709536. + default: + if (allow_mixed_precision_) { + return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(), + inferred_shape); + } else { + return ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } + } + }(); + if (!equal) { return InternalError( - "Expected instruction to have shape compatible with %s, actual " + "Expected instruction to have shape equal to %s, actual " "shape is %s:\n%s", - ShapeUtil::HumanString(inferred_shape).c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - instruction->ToString().c_str()); + StringifyShape(inferred_shape), StringifyShape(instruction->shape()), + instruction->ToString()); } return Status::OK(); } @@ -628,17 +699,17 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { instruction->opcode(), instruction->operands())); } -string ComputationsToString( - tensorflow::gtl::ArraySlice computations) { - return tensorflow::str_util::Join( - computations, ",", [](string* s, const HloComputation* computation) { - s->append(computation->name()); - }); +string ComputationsToString(absl::Span computations) { + return absl::StrJoin(computations, ",", + [](string* s, const HloComputation* computation) { + s->append(computation->name()); + }); } // Verifies various invariants about the structure of the HLO: // -// (1) each instruction has a non-null parent() set to the HloComputation which +// (1) each instruction has a non-null parent() set to the HloComputation +// which // contains it. // // (2) each computation has a non-null parent() set to the HloModule which @@ -650,31 +721,31 @@ Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { return InternalError("Computation %s has a null parent pointer", - computation->name().c_str()); + computation->name()); } if (computation->parent() != module) { return InternalError( "Computation %s parent() does not point to parent module", - computation->name().c_str()); + computation->name()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { return InternalError("Instruction %s has a null parent pointer", - instruction->name().c_str()); + instruction->name()); } if (instruction->parent() != computation) { return InternalError( "Instruction %s parent() does not point to parent computation", - instruction->name().c_str()); + instruction->name()); } } } // Check that operands are in the same computation separately from verifying - // parent() correctness so conditions like a null HloInstruction::parent() are - // identified and reported explicitly above rather than reporting a mismatched - // operand. + // parent() correctness so conditions like a null HloInstruction::parent() + // are identified and reported explicitly above rather than reporting a + // mismatched operand. for (const HloComputation* computation : module->computations()) { for (const HloInstruction* instruction : computation->instructions()) { for (int i = 0; i < instruction->operand_count(); ++i) { @@ -683,9 +754,8 @@ Status VerifyHloStructure(HloModule* module) { return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", - i, operand->name().c_str(), instruction->name().c_str(), - operand->parent()->name().c_str(), - instruction->parent()->name().c_str()); + i, operand->name(), instruction->name(), + operand->parent()->name(), instruction->parent()->name()); } } } @@ -698,13 +768,14 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { return InternalError( - "Instruction of fused computation does not match expected instruction " + "Instruction of fused computation does not match expected " + "instruction " "%s.", - fusion->ToString().c_str()); + fusion->ToString()); } - // Fused root instruction and fused parameters must all be owned by the fusion - // computation. + // Fused root instruction and fused parameters must all be owned by the + // fusion computation. bool root_owned = false; const std::vector& fused_parameters = fusion->fused_parameters(); @@ -714,7 +785,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (fused_root == instruction) { if (root_owned) { return InternalError("Root appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } root_owned = true; } @@ -722,7 +793,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { return InternalError("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } parameter_owned[i] = true; } @@ -730,76 +801,68 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } if (!root_owned) { return InternalError("Root not found in computation of %s.", - fusion->ToString().c_str()); + fusion->ToString()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { return InternalError("Parameter %d not found in computation of %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } // Fused root must have no users. if (fused_root->user_count() != 0) { - return InternalError("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", fusion->ToString()); } - // All uses of fused instructions must be in the fusion computation, and every - // non-root instruction must have at least one use. + // All uses of fused instructions must be in the fusion computation, and + // every non-root instruction must have at least one use. for (auto* instruction : fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { return InternalError("Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), - fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { return InternalError( "Non-root instruction %s in %s may not have external users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } } } } // Fused parameter instructions must be numbered contiguously and match up - // (shapes compatible) with their respective operand. + // (shapes equal) with their respective operand. CHECK_EQ(fusion->operands().size(), fused_parameters.size()); std::vector parameter_numbers(fused_parameters.size(), false); for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return InternalError("Unexpected negative parameter number %lld in %s.", - param_no, fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %d in %s.", + param_no, fusion->ToString()); } if (param_no >= fused_parameters.size()) { return InternalError( - "Unexpected parameter number %lld in %s: higher then number of " + "Unexpected parameter number %d in %s: higher then number of " "parameters %lu.", - param_no, fusion->ToString().c_str(), fused_parameters.size()); + param_no, fusion->ToString(), fused_parameters.size()); } if (parameter_numbers[param_no]) { return InternalError( - "Did not expect parameter number %lld more than once in %s.", - param_no, fusion->ToString().c_str()); + "Did not expect parameter number %d more than once in %s.", param_no, + fusion->ToString()); } parameter_numbers[param_no] = true; - if (!ShapeUtil::Compatible(fused_param->shape(), - fusion->operand(param_no)->shape())) { - return InternalError( - "Shape mismatch between parameter number %lld and its operand in %s.", - param_no, fusion->ToString().c_str()); - } } // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { return InternalError("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } @@ -814,18 +877,18 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { auto* while_body = instruction->while_body(); if (while_cond->num_parameters() != 1) { return FailedPrecondition( - "While condition must have exactly 1 parameter; had %lld : %s", - while_cond->num_parameters(), while_cond->ToString().c_str()); + "While condition must have exactly 1 parameter; had %d : %s", + while_cond->num_parameters(), while_cond->ToString()); } if (while_body->num_parameters() != 1) { return FailedPrecondition( - "While body must have exactly 1 parameter; had %lld : %s", - while_body->num_parameters(), while_body->ToString().c_str()); + "While body must have exactly 1 parameter; had %d : %s", + while_body->num_parameters(), while_body->ToString()); } if (instruction->operand_count() != 1) { return FailedPrecondition( - "While loop must have exactly one operand; had %lld : %s", - instruction->operand_count(), instruction->ToString().c_str()); + "While loop must have exactly one operand; had %d : %s", + instruction->operand_count(), instruction->ToString()); } return Status::OK(); } @@ -833,16 +896,14 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { if (instruction->true_computation()->num_parameters() != 1) { return FailedPrecondition( - "True computation %s of %s must have 1 parameter insted of %lld", - instruction->true_computation()->name().c_str(), - instruction->ToString().c_str(), + "True computation %s of %s must have 1 parameter insted of %d", + instruction->true_computation()->name(), instruction->ToString(), instruction->true_computation()->num_parameters()); } if (instruction->false_computation()->num_parameters() != 1) { return FailedPrecondition( - "False computation %s of %s must have 1 parameter insted of %lld", - instruction->false_computation()->name().c_str(), - instruction->ToString().c_str(), + "False computation %s of %s must have 1 parameter insted of %d", + instruction->false_computation()->name(), instruction->ToString(), instruction->false_computation()->num_parameters()); } return Status::OK(); @@ -855,11 +916,11 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." - "Found non-compatible shapes for instruction %s.\n" + "Found different shapes for instruction %s.\n" "output: %s\noperand: %s\n", - HloOpcodeString(instruction->opcode()).c_str(), - ShapeUtil::HumanString(out_shape).c_str(), - ShapeUtil::HumanString(operand_shape).c_str()); + HloOpcodeString(instruction->opcode()), + ShapeUtil::HumanString(out_shape), + ShapeUtil::HumanString(operand_shape)); } } return Status::OK(); @@ -890,7 +951,7 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { if (ShapeContainsToken(param->shape())) { return InternalError( "Entry parameter %d is or contains a token shape: %s", i, - ShapeUtil::HumanString(param->shape()).c_str()); + ShapeUtil::HumanString(param->shape())); } } return Status::OK(); @@ -902,15 +963,16 @@ Status CheckSameChannel(const HloInstruction* instr1, if (instr1->channel_id() != instr2->channel_id()) { return InternalError( "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); + "(%d), %s (%d)", + instr1->ToString(), instr1->channel_id(), instr2->ToString(), + instr2->channel_id()); } return Status::OK(); } -// Checks if the given two instructions have the same is_host_transfer attribute -// value. Intsructions must be send/recv instructions or their 'done' variant. +// Checks if the given two instructions have the same is_host_transfer +// attribute value. Intsructions must be send/recv instructions or their +// 'done' variant. Status CheckSameIsHostTransfer(const HloInstruction* instr1, const HloInstruction* instr2) { const HloSendRecvInstruction* send_recv1 = @@ -921,9 +983,10 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1, TF_RET_CHECK(send_recv2 != nullptr); if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { return InternalError( - "Expected instructions to have the same is-host-transfer property: %s, " + "Expected instructions to have the same is-host-transfer property: " + "%s, " "%s ", - instr1->ToString().c_str(), instr2->ToString().c_str()); + instr1->ToString(), instr2->ToString()); } return Status::OK(); } @@ -940,11 +1003,12 @@ Status VerifySendsAndRecvs(const HloModule& module) { host_channels.insert({sendrecv->channel_id(), sendrecv}); if (!it_inserted.second) { return FailedPrecondition( - "Channel %lld is used for multiple host send/recv instructions: %s " + "Channel %d is used for multiple host send/recv instructions: " + "%s " "and " "%s", - sendrecv->channel_id(), sendrecv->ToString().c_str(), - it_inserted.first->second->ToString().c_str()); + sendrecv->channel_id(), sendrecv->ToString(), + it_inserted.first->second->ToString()); } } @@ -993,6 +1057,7 @@ Status VerifySendsAndRecvs(const HloModule& module) { } // namespace StatusOr HloVerifier::Run(HloModule* module) { + TF_RET_CHECK(!module->name().empty()); TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); @@ -1003,9 +1068,9 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RET_CHECK(instruction->parent() == computation); if (instruction->opcode() == HloOpcode::kFusion) { TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction)); - TF_RET_CHECK( - ContainersEqual(instruction->called_computations(), - {instruction->fused_instructions_computation()})) + TF_RET_CHECK(instruction->called_computations() == + absl::Span( + {instruction->fused_instructions_computation()})) << "Fusion HLO calls computations other than the " "fused_instructions_computation: " << instruction->ToString() @@ -1059,6 +1124,11 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + // If the module has a schedule, it must be valid. + if (module->has_schedule()) { + TF_RETURN_IF_ERROR(module->schedule().Verify()); + } + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 7feddaeabf9f944ed9cd4f5672ef63a7f9da2e40..0cde4a31af72e81829723c564f59edc362f73335 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/shape_inference.h" namespace xla { @@ -27,9 +28,9 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. class ShapeVerifier : public DfsHloVisitor { public: - explicit ShapeVerifier() : allow_mixed_precision_(false) {} - explicit ShapeVerifier(bool allow_mixed_precision) - : allow_mixed_precision_(allow_mixed_precision) {} + explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision) + : layout_sensitive_(layout_sensitive), + allow_mixed_precision_(allow_mixed_precision) {} Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; @@ -45,6 +46,8 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllToAll(HloInstruction* hlo) override; + Status HandleCollectivePermute(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -62,7 +65,6 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFusion(HloInstruction*) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction*) override; - Status HandleHostCompute(HloInstruction*) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( @@ -105,6 +107,42 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckVariadicShape(const HloInstruction* instruction); private: + // Helpers that switch on layout_sensitive_. + bool ShapesSame(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::Equal(a, b) + : ShapeUtil::Compatible(a, b); + } + bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b) + : ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + } + string StringifyShape(const Shape& s) { + return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) + : ShapeUtil::HumanString(s); + } + + // Checks that the given operand of the given instruction is of type TOKEN. + Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no); + + // Checks that the shape of the given operand of the given instruction matches + // the given parameter of the given computation. + Status CheckOperandAndParameter(const HloInstruction* instruction, + int64 operand_number, + const HloComputation* computation, + int64 parameter_number); + + // Returns true if the shapes of the two operands have the same element type, + // and the result shape either has the same element type as the operand shapes + // or mixed precision is allowed and the result shape and the operand shapes + // have floating point element types. + bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, + const Shape& result_shape); + + // If the verifier is layout-sensitive, shapes must be equal to what's + // expected. Otherwise, the shapes must simply be compatible. + bool layout_sensitive_; + // Whether the inputs and output of an instruction can contain both F32s and // BF16s. Tuples that include both F32s and BF16s are allowed regardless of // this flag. @@ -113,18 +151,14 @@ class ShapeVerifier : public DfsHloVisitor { // HLO pass that verifies invariants of HLO instructions for each computation in // the module. -class HloVerifier : public HloPassInterface { +class HloVerifier : public HloModulePass { public: using ShapeVerifierFactory = std::function()>; - // Uses standard shape inference. - explicit HloVerifier() - : shape_verifier_factory_( - [] { return MakeUnique(false); }) {} - - explicit HloVerifier(bool allow_mixed_precision) - : shape_verifier_factory_([allow_mixed_precision] { - return MakeUnique(allow_mixed_precision); + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { + return absl::make_unique(layout_sensitive, + allow_mixed_precision); }) {} // Uses custom shape verification. @@ -132,10 +166,9 @@ class HloVerifier : public HloPassInterface { : shape_verifier_factory_(std::move(shape_verifier_factory)) {} ~HloVerifier() override = default; - tensorflow::StringPiece name() const override { return "verifier"; } + absl::string_view name() const override { return "verifier"; } - // Note: always returns false (no instructions are ever modified by this - // pass). + // Never returns true; no instructions are ever modified by this pass. StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 04c6ba3eeb92bad2b5b69f7f56e73e1f7a8148aa..8f0423bb1c72ceb209437116a898d027f4d2c657 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -34,7 +34,21 @@ namespace { using ::testing::HasSubstr; -using HloVerifierTest = HloTestBase; +// This class cannot be converted to use HloVerifiedTestBase. It explicitly +// uses HloTestBase to create and test malformed HLOs. +class HloVerifierTest : public HloTestBase { + public: + HloVerifierTest() + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} +}; + +class HloVerifierTestAllowMixedPrecision : public HloTestBase { + public: + HloVerifierTestAllowMixedPrecision() + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} +}; TEST_F(HloVerifierTest, NullInstructionParent) { HloComputation::Builder builder(TestName()); @@ -174,5 +188,175 @@ ENTRY entry { HasSubstr("shape does not match parameter")); } +TEST_F(HloVerifierTest, RngOpnd0NotScalar) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOpnd0NotScalar { + constant.0 = f32[] constant(0) + constant.1 = f16[2] constant({1, 3}) + ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1), + distribution=rng_uniform + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("Expected scalar type")); +} + +TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + constant.0 = f32[] constant(0) + constant.1 = f16[] constant(1) + ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected compatible element types")); +} + +TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngResultElementTypeNotMatch { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected compatible element types")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngResultElementTypeNotMatch { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, RngElementTypeNotSupported) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngElementTypeNotSupported { + constant.0 = s32[] constant(0) + constant.1 = s32[] constant(1) + ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported")); +} + +TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { + // This testcase can't be written using textual HLO, because it doesn't parse + // negative interior padding. That's probably a feature. :) + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {100}), "param")); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_interior_padding(-1); + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {100}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), + padding_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Interior padding cannot be negative")); +} + +TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { + // This testcase can't be written using textual HLO, because it doesn't parse + // negative interior padding. That's probably a feature. :) + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {100}), "param")); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_interior_padding(-1); + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {100}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), + padding_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("Interior padding cannot be negative")); +} + +// Simple module containing a convolution as the root. +static const char* const kConvHloString = R"( +HloModule module +ENTRY entry_computation { + param0 = f16[128,128,56,56] parameter(0) + param1 = f16[3,3,128,128] parameter(1) + zero_f16 = f16[] constant(0) + ROOT conv = f16[128,128,28,28] convolution(param0, param1), + window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01 +})"; + +TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString)); + auto* conv = module->entry_computation()->root_instruction(); + Window w = conv->window(); + w.mutable_dimensions(0)->set_window_dilation(-1); + conv->set_window(w); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("non-positive window dilation factor")); +} + +TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString)); + auto* conv = module->entry_computation()->root_instruction(); + Window w = conv->window(); + w.mutable_dimensions(0)->set_base_dilation(-1); + conv->set_window(w); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("non-positive base area dilation factor")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index bb5b40a8a87c5eab5a5b1599581a81bbd064511b..e76b93107c923b41666f6b0a388dda143a8cb50a 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -14,27 +14,27 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/metric_table_report.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -using tensorflow::strings::Appendf; +using absl::StrAppend; +using absl::StrAppendFormat; +using absl::StrCat; +using absl::StrFormat; using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; - Appendf(&s, "Execution profile for %s: (%s @ f_nom)\n", - computation_name_.c_str(), - HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); + StrAppendFormat(&s, "Execution profile for %s: (%s @ f_nom)\n", + computation_name_, + HumanReadableElapsedTime(CyclesToSeconds(total_cycles_))); int64 cumulative_cycles = 0; auto print_op = [&](const OpInfo& op, bool is_total = false) { @@ -56,7 +56,7 @@ string HumanReadableProfileBuilder::ToString() const { if (op.bytes_accessed > op.cycles) { bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = Printf("%.3fB/cycle", bpc); + bytes_per_cycle = StrFormat("%.3fB/cycle", bpc); } } @@ -77,27 +77,24 @@ string HumanReadableProfileBuilder::ToString() const { // columns in the output. cycles_percent_str = "100.% 100Σ"; } else { - cycles_percent_str = - Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent); + cycles_percent_str = StrFormat("%5.2f%% %2.0fΣ", cycles_percent, + cumulative_cycles_percent); } double nsecs = op.cycles / clock_rate_ghz_; - Appendf( + StrAppendFormat( &s, - "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " + "%15d cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " "%16s :: %s\n", - op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles), + op.cycles, cycles_percent_str, CyclesToMicroseconds(op.cycles), op.optimal_seconds < 0 ? "" - : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), - op.flop_count <= 0 - ? "" - : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6), + op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs), op.transcendental_count <= 0 ? "" - : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) - .c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs), + bytes_per_sec, bytes_per_cycle, op.name); }; float optimal_seconds_sum = 0.0; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index 6f56c3aa82e9d1c942fd67ff7a5948cf2e54370d..925111fa1f1e48650b0089f402d92e431043eabe 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -29,10 +29,10 @@ namespace xla { // computation, suitable for consumption by humans. class HumanReadableProfileBuilder { public: - explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, + explicit HumanReadableProfileBuilder(absl::string_view computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(std::string(computation_name)), + : computation_name_(computation_name), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -43,15 +43,13 @@ class HumanReadableProfileBuilder { // Adds an operation to the profile. If you don't know the number of // floating-point ops or bytes touched by the op, or if you don't know how // fast it would run optimally, pass -1 for that param. - void AddOp(tensorflow::StringPiece op_name, - tensorflow::StringPiece short_name, - tensorflow::StringPiece category, int64 cycles, int64 flop_count, + void AddOp(absl::string_view op_name, absl::string_view short_name, + absl::string_view category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back({std::string(op_name), std::string(short_name), - std::string(category), cycles, flop_count, - transcendental_count, bytes_accessed, - optimal_seconds}); + op_infos_.push_back({string(op_name), string(short_name), string(category), + cycles, flop_count, transcendental_count, + bytes_accessed, optimal_seconds}); } // Gets the human-readable profile. diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h index aa325dc8a353c5bfbfded0c2774c66bfcc71c9cb..9c48b7db613b049536c76237b4cfebbbc47448f3 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -25,12 +25,12 @@ namespace xla { // Pass which replaces all implicit broadcasts with their equivalent sequence of // explicit broadcast and reshape instructions. -class ImplicitBroadcastRemover : public HloPassInterface { +class ImplicitBroadcastRemover : public HloModulePass { public: ImplicitBroadcastRemover() {} ~ImplicitBroadcastRemover() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "implicit-broadcast-remover"; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 8b2df3256776a7d77517daff1fe282b0dbde7045..06f0e1ed25e71659a61e6de8a84e52cf70064eae 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -14,13 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/gtl/optional.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gtl = ::tensorflow::gtl; @@ -31,32 +34,29 @@ using UnknownArray = Analysis::UnknownArray; using ConstantArray = Analysis::ConstantArray; using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; -using tensorflow::gtl::ArraySlice; -using tensorflow::str_util::Join; +using absl::StrJoin; } // namespace string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { switch (root->kind()) { case Array::kUnknown: { auto* unknown_tensor = root->as(); - return tensorflow::strings::StrCat("%", - unknown_tensor->instruction().name()); + return absl::StrCat("%", unknown_tensor->instruction().name()); } case Array::kConstant: { if (print_constants) { string contents = root->as()->literal()->ToString(); - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, - ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + " ", contents, ")"); } - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + ")"); } case Array::kReshaped: { ReshapedArray* reshaped_array = root->as(); - return tensorflow::strings::StrCat( + return absl::StrCat( "(reshape ", ToString(reshaped_array->operand(), print_constants), " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); } @@ -67,11 +67,11 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { string name = root->kind() == Array::kScalarIndexedConstant ? "scalar-indexed-const" : "scalar-indexed"; - return tensorflow::strings::StrCat( + return absl::StrCat( "(", name, " ", ToString(indexed_array->source(), print_constants), " ", ToString(indexed_array->indices(), print_constants), " ", indexed_array->source_dim(), "->[", - Join(indexed_array->output_dims(), ","), "])"); + StrJoin(indexed_array->output_dims(), ","), "])"); } } } @@ -92,7 +92,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( // Depth first search over the DAG, invoking ComputeArrayFor in post order. // The HLO instructions already in the cache are considered leaves. - gtl::InlinedVector stack; + absl::InlinedVector stack; enum DfsState { kDiscovered, kVisited }; gtl::FlatMap dfs_state_map; @@ -153,7 +153,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( TF_ASSIGN_OR_RETURN( computed_array, ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), - instr->gather_window_bounds(), + instr->gather_slice_sizes(), FindOrDie(cache_, instr->operand(0)), FindOrDie(cache_, instr->operand(1)))); } else if (instr->opcode() == HloOpcode::kReshape) { @@ -165,6 +165,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( TF_ASSIGN_OR_RETURN( computed_array, ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(), + instr->precision_config(), FindOrDie(cache_, instr->operand(0)), FindOrDie(cache_, instr->operand(1)))); } else { @@ -185,7 +186,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( StatusOr IndexedArrayAnalysis::FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, - tensorflow::gtl::ArraySlice output_dims, Shape shape) { + absl::Span output_dims, Shape shape) { // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). // `source` is the inner Gather(A, X). @@ -251,24 +252,22 @@ StatusOr IndexedArrayAnalysis::FoldGatherOfGather( StatusOr IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice window_bounds, Array* source, - Array* indices) { + absl::Span slice_sizes, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; } - CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + CHECK_EQ(dim_numbers.start_index_map_size(), 1); - // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should - // it become relevant. + // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here, + // should it become relevant. - if (dim_numbers.elided_window_dims_size() != 1 || - dim_numbers.elided_window_dims(0) != - dim_numbers.gather_dims_to_operand_dims(0)) { + if (dim_numbers.collapsed_slice_dims_size() != 1 || + dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) { VLOG(3) << "ComputeArrayForGather: gather operations must elide " - "gather_dims_to_operand_dims[0] and " - "gather_dims_to_operand_dims[0] only"; + "start_index_map[0] and " + "start_index_map[0] only"; return nullptr; } @@ -277,27 +276,27 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForGather( // arrays from an array of size [7,4,6]. We check that condition down below: for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) { - if (i != dim_numbers.elided_window_dims(0) && - source->shape().dimensions(i) != window_bounds[i]) { - VLOG(3) << "ComputeArrayForGather: window_bounds[" << i + if (i != dim_numbers.collapsed_slice_dims(0) && + source->shape().dimensions(i) != slice_sizes[i]) { + VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i << "] != source->shape().dimensions(" << i << ") -- " - << source->shape().dimensions(i) << " vs. " << window_bounds[i] - << " with dim_numbers.elided_window_dims(0) = " - << dim_numbers.elided_window_dims(0); + << source->shape().dimensions(i) << " vs. " << slice_sizes[i] + << " with dim_numbers.collapsed_slice_dims(0) = " + << dim_numbers.collapsed_slice_dims(0); return nullptr; } } - int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + int64 source_dim = dim_numbers.start_index_map(0); std::vector output_dims; for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { output_dims.push_back(i); } } if (auto* indexed = dynamic_cast(source)) { - if (c_linear_search(indexed->output_dims(), source_dim)) { + if (absl::c_linear_search(indexed->output_dims(), source_dim)) { return FoldGatherOfGather(indexed, indices, source_dim, output_dims, shape); } @@ -314,8 +313,8 @@ namespace { // Returns an index into `values` such that the product of the range // [values.begin()+index, values.end()) is equal to `product`. If there is no // such index, return -1. All integers in `values` must be positive. -int64 FindSuffixWithProduct(ArraySlice values, int64 product) { - DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); +int64 FindSuffixWithProduct(absl::Span values, int64 product) { + DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; })); int64 current_product = 1; int64 i; @@ -343,7 +342,8 @@ struct ReshapePassthroughDimPair { // The returned vector of pairs is sorted in both the result_dim and the // operand_dim components. std::vector ComputeReshapePassthroughDimPairs( - ArraySlice operand_shape, ArraySlice result_shape) { + absl::Span operand_shape, + absl::Span result_shape) { // A reshape can be seen as an index mapping from output index to input index: // // (i_0, ..., i_n) = f(o_0, ..., o_m) @@ -378,8 +378,8 @@ std::vector ComputeReshapePassthroughDimPairs( CHECK_NE(candidate_operand_dim, 0) << "result_dim = " << result_dim << ", result_subarray_size = " << result_subarray_size - << ", result_shape = [" << Join(result_shape, ",") << "]" - << ", operand_shape = [" << Join(operand_shape, ",") << "]"; + << ", result_shape = [" << StrJoin(result_shape, ",") << "]" + << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]"; if (candidate_operand_dim != -1 && result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { @@ -389,26 +389,27 @@ std::vector ComputeReshapePassthroughDimPairs( result_subarray_size *= result_shape[result_dim]; } - c_reverse(result); + absl::c_reverse(result); if (VLOG_IS_ON(3)) { std::vector result_strings; - c_transform(result, std::back_inserter(result_strings), - [](ReshapePassthroughDimPair value) { - return tensorflow::strings::StrCat(value.result_dim, "->", - value.operand_dim); - }); - VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" - << Join(result_shape, ",") << "] passthrough indices are [" - << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; + absl::c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return absl::StrCat(value.result_dim, "->", + value.operand_dim); + }); + VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to [" + << StrJoin(result_shape, ",") << "] passthrough indices are [" + << StrJoin(result_strings, ",") + << "] (legend: `result`->`operand`)"; } - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return lhs.result_dim < rhs.result_dim; })); - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return lhs.operand_dim < rhs.operand_dim; })); @@ -419,43 +420,44 @@ std::vector ComputeReshapePassthroughDimPairs( // Return true if `dim` is stated as an passthrough operand dim in // `passthrough_dims`. bool IsReshapePassthroughOperandDim( - ArraySlice passthrough_dims, int64 dim) { - return c_any_of(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == dim; - }); + absl::Span passthrough_dims, int64 dim) { + return absl::c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); } // Maps `operand_dim` which must be an passthrough operand dimension to its // corresponding passthrough result dimension based on `passthrough_dims`. int64 MapPassthroughOperandDimToResultDim( - ArraySlice passthrough_dims, int64 operand_dim) { - auto it = c_find_if(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == operand_dim; - }); + absl::Span passthrough_dims, + int64 operand_dim) { + auto it = absl::c_find_if( + passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); CHECK(it != passthrough_dims.end()); return it->result_dim; } -int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, - ArraySlice result_shape, - int64 source_passthrough_dim) { +int64 FindSourcePositionForPassthroughResultDim( + absl::Span operand_shape, absl::Span result_shape, + int64 source_passthrough_dim) { VLOG(3) << "FindSourcePositionForPassthroughResultDim([" - << Join(operand_shape, ",") << "], [" << Join(result_shape, ",") + << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") << "], " << source_passthrough_dim << ")"; int64 indexed_source_subarray_size = std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, - operand_shape.end(), 1, std::multiplies()); + operand_shape.end(), 1LL, std::multiplies()); return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); } Shape StripDegenerateDimensions(const Shape& shape) { DimensionVector new_dims; - c_copy_if(shape.dimensions(), std::back_inserter(new_dims), - [](int64 dim) { return dim != 1; }); + absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims), + [](int64 dim) { return dim != 1; }); return ShapeUtil::MakeShape(shape.element_type(), new_dims); } }; // namespace @@ -498,7 +500,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { if (shape.dimensions(i) == 1) { degenerate_dims_seen++; - } else if (ArrayContains(operand->output_dims(), i)) { + } else if (absl::c_linear_search(operand->output_dims(), i)) { new_output_dims.push_back(i - degenerate_dims_seen); } } @@ -518,8 +520,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( } StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( - ScalarIndexedArray* operand, - tensorflow::gtl::ArraySlice degenerate_dims) { + ScalarIndexedArray* operand, absl::Span degenerate_dims) { if (degenerate_dims.empty()) { return operand; } @@ -531,7 +532,7 @@ StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( // element is true iff the i'th component of the result index is an output // index. - gtl::InlinedVector output_dims_bitvector( + absl::InlinedVector output_dims_bitvector( operand->shape().dimensions_size()); for (int64 output_dim : operand->output_dims()) { output_dims_bitvector[output_dim] = true; @@ -553,8 +554,8 @@ StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( }(); DimensionVector new_result_shape_dims; - c_copy(operand->shape().dimensions(), - std::back_inserter(new_result_shape_dims)); + absl::c_copy(operand->shape().dimensions(), + std::back_inserter(new_result_shape_dims)); for (int64 degenerate_dim : degenerate_dims) { InsertAt(&new_result_shape_dims, degenerate_dim, 1); } @@ -695,8 +696,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( operand_dim); }; - if (!c_all_of(scalar_indexed->output_dims(), - is_reshape_passthrough_operand_dim)) { + if (!absl::c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { VLOG(3) << "Not all output dims are passthrough dims " << ToString(scalar_indexed); return nullptr; @@ -735,11 +736,11 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( // operand = s32[3,5,2] constant({...}) // indices = s32[7] parameter(0) // gather = s32[3,2,7] gather(operand, indices), - // output_window_dims={0,1}, - // elided_window_dims={1}, - // gather_dims_to_operand_dims={1}, + // offset_dims={0,1}, + // collapsed_slice_dims={1}, + // start_index_map={1}, // index_vector_dim=1, - // window_bounds={3,1,2} + // slice_sizes={3,1,2} // reshape = s32[6,7] reshape(gather) // // In this case the gather maps to: @@ -754,9 +755,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( if (source_dim_for_new_scalar_indexed_node == -1) { VLOG(3) << "Could not compute the source dim for the new scalar indexed " "node: scalar_indexed_source_shape = [" - << Join(scalar_indexed_source_shape.dimensions(), ",") + << StrJoin(scalar_indexed_source_shape.dimensions(), ",") << "] and new_scalar_indexed_source_shape = [" - << Join(new_scalar_indexed_source_shape, ",") << "]"; + << StrJoin(new_scalar_indexed_source_shape, ",") << "]"; return nullptr; } @@ -764,8 +765,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); - CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1l, - std::multiplies()), + CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL, + std::multiplies()), ShapeUtil::ElementsIn(scalar_indexed_source_shape)); CHECK(IsReshapePassthroughOperandDim( @@ -781,9 +782,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( }; std::vector output_dims_for_new_scalar_indexed_node; - c_transform(scalar_indexed->output_dims(), - std::back_inserter(output_dims_for_new_scalar_indexed_node), - map_passthrough_operand_dim_to_result_dim); + absl::c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, TakeOwnership(scalar_indexed->literal().Reshape( @@ -872,13 +873,14 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, return nullptr; } - ArraySlice broadcast_dims = broadcast_instr->dimensions(); + absl::Span broadcast_dims = broadcast_instr->dimensions(); auto is_broadcasted_dim = [&](int64 output_dim) { - return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end(); }; // All of the output dims must be "broadcasted" dims for the other operand. - if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + if (!absl::c_all_of(scalar_indexed_const->output_dims(), + is_broadcasted_dim)) { return nullptr; } @@ -894,7 +896,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // The scalar-indexed node "removes" the source dim and "inserts" the output // dims. We do the opposite here to undo the scalar-indexed operation. - ArraySlice output_dims = scalar_indexed_const->output_dims(); + absl::Span output_dims = scalar_indexed_const->output_dims(); for (int64 i = output_dims.size() - 1; i >= 0; --i) { CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); EraseAt(&simulated_index, output_dims[i]); @@ -916,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // inner_broadcast_result is the Broadcast'(Const0) bit in // BinaryOp(Broadcast'(Const0), Const1) TF_ASSIGN_OR_RETURN( - std::unique_ptr inner_broadcast_result, + Literal inner_broadcast_result, broadcast_const_operand->literal().Broadcast( scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); @@ -926,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + opcode, scalar_indexed_const->literal(), inner_broadcast_result))); } else { TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + opcode, inner_broadcast_result, scalar_indexed_const->literal()))); } ConstantArray* new_source = Construct(literal_for_new_source); @@ -970,15 +972,15 @@ namespace { // Returns the non-contracting non-batch dimension (as per `contracting_dims` // and `batch_dims`) if there is exactly one, otherwise returns nullopt. -gtl::optional GetOnlyNonContractingNonBatchDim( - int64 rank, ArraySlice contracting_dims, - ArraySlice batch_dims) { - gtl::optional result; +absl::optional GetOnlyNonContractingNonBatchDim( + int64 rank, absl::Span contracting_dims, + absl::Span batch_dims) { + absl::optional result; for (int64 dim = 0; dim < rank; dim++) { - if (!ArrayContains(contracting_dims, dim) && - !ArrayContains(batch_dims, dim)) { + if (!absl::c_linear_search(contracting_dims, dim) && + !absl::c_linear_search(batch_dims, dim)) { if (result.has_value()) { - return gtl::nullopt; + return absl::nullopt; } result = dim; } @@ -995,10 +997,10 @@ gtl::optional GetOnlyNonContractingNonBatchDim( // `contracting_dims` and `batch_dims` are the contracting and batch dimensions // of whatever operand `indexed_array` is to the dot (LHS or RHS). bool CanFoldDotIntoIndexedArray( - tensorflow::StringPiece tag, - Analysis::ScalarIndexedConstantArray* indexed_array, - ArraySlice contracting_dims, ArraySlice batch_dims) { - gtl::optional non_contracting_non_batch_dim = + absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array, + absl::Span contracting_dims, + absl::Span batch_dims) { + absl::optional non_contracting_non_batch_dim = GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), contracting_dims, batch_dims); if (!non_contracting_non_batch_dim.has_value()) { @@ -1029,7 +1031,8 @@ bool CanFoldDotIntoIndexedArray( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs) { + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1044,9 +1047,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, lhs->literal(), *rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, lhs->literal(), *rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting LHS // dimension "went". @@ -1062,7 +1066,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs) { + const PrecisionConfig& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1078,9 +1083,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( new_dim_numbers.set_rhs_contracting_dimensions( 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, *lhs->literal(), rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, *lhs->literal(), rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting RHS // dimension "went". @@ -1094,8 +1100,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( } StatusOr IndexedArrayAnalysis::ComputeArrayForDot( - const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs, - Array* rhs) { + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, Array* lhs, Array* rhs) { // Intuitively, if // // - The LHS of a dot product is a gathered sequence of rows from a constant @@ -1118,6 +1124,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( dynamic_cast(lhs)) { if (auto* rhs_constant = dynamic_cast(rhs)) { return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers, + precision_config, lhs_indexed_array, rhs_constant); } } @@ -1125,7 +1132,8 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( if (auto* rhs_indexed_array = dynamic_cast(rhs)) { if (auto* lhs_constant = dynamic_cast(lhs)) { - return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant, + return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, + precision_config, lhs_constant, rhs_indexed_array); } } @@ -1133,7 +1141,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( return nullptr; } -tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { +absl::string_view IndexedArrayAnalysisPrinterPass::name() const { return "indexed-array-analysis-printer-pass"; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index e923dc39f7f464a8d3c400294499a6f5efda3991..3e238f97a03fb71cddf59da69b0389731314ff49 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -188,9 +188,7 @@ class IndexedArrayAnalysis { // `output_dims` are the dimensions in the output array that are being used // to compute an index into the `indices` array. See the class // documentation and the overview for more details. - tensorflow::gtl::ArraySlice output_dims() const { - return output_dims_; - } + absl::Span output_dims() const { return output_dims_; } private: explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, @@ -265,19 +263,21 @@ class IndexedArrayAnalysis { StatusOr ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice window_bounds, Array* source, - Array* indices); + absl::Span slice_sizes, Array* source, Array* indices); StatusOr ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs); + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs); StatusOr ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs); + const PrecisionConfig& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs); StatusOr ComputeArrayForDot(const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, Array* lhs, Array* rhs); // This tries to fold a ScalarIndexedArray which has another @@ -303,7 +303,7 @@ class IndexedArrayAnalysis { // G1 = [Arr[i] for i in I2] StatusOr FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, - tensorflow::gtl::ArraySlice output_dims, Shape shape); + absl::Span output_dims, Shape shape); // Reshapes a scalar-indexed node to remove the degenerate dimensions in its // output. The result is always a scalar-indexed node. @@ -313,8 +313,7 @@ class IndexedArrayAnalysis { // Reshapes a scalar-indexed node such that the result has the degenerate // dimensions `degenerate_dims`. The result is always a scalar-indexed node. StatusOr ReshapeToAddDegenerateDims( - ScalarIndexedArray* operand, - tensorflow::gtl::ArraySlice degenerate_dims); + ScalarIndexedArray* operand, absl::Span degenerate_dims); StatusOr FoldReshapeOfGather( const Shape& shape, ScalarIndexedConstantArray* operand); @@ -348,30 +347,28 @@ class IndexedArrayAnalysis { } } - Literal* TakeOwnership(std::unique_ptr literal) { + Literal* TakeOwnership(Literal literal) { owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } - StatusOr TakeOwnership( - StatusOr> literal_or_error) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - std::move(literal_or_error)); + StatusOr TakeOwnership(StatusOr literal_or_error) { + TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } std::vector> owned_tensors_; - std::vector> owned_literals_; + std::vector owned_literals_; tensorflow::gtl::FlatMap cache_; }; // A pass that prints all non-trivial results returned by IndexedArrayAnalysis. // This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to // unconditionally add to the regular HLO pass pipeline. -class IndexedArrayAnalysisPrinterPass : public HloPassInterface { +class IndexedArrayAnalysisPrinterPass : public HloModulePass { public: - tensorflow::StringPiece name() const override; + absl::string_view name() const override; StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 5f4b42799b1c26ea544f9d4447cc45b5ae9d5a48..2d03aebc1aca4c55cca588072233b7a18e70a306 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -82,11 +82,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -102,11 +102,11 @@ ENTRY main { operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5] parameter(0) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -122,11 +122,11 @@ ENTRY main { operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5,2] parameter(0) ROOT gather = s32[5] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} } )"; @@ -141,11 +141,11 @@ ENTRY main { operand = s32[3,3,1] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,2}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0,2}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3,1} + slice_sizes={1,3,1} } )"; @@ -160,11 +160,11 @@ ENTRY main { operand = s32[3,3,1] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,2,3] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={2}, - gather_dims_to_operand_dims={0}, + offset_dims={1,2}, + collapsed_slice_dims={2}, + start_index_map={0}, index_vector_dim=1, - window_bounds={2,3,1} + slice_sizes={2,3,1} } )"; @@ -179,11 +179,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,2} + slice_sizes={1,2} } )"; @@ -199,17 +199,17 @@ ENTRY main { indices_a = s32[5] parameter(0) indices_b = s32[2] parameter(1) gather_a = s32[5,3] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} ROOT gather_b = s32[2,3] gather(gather_a, indices_b), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -228,17 +228,17 @@ ENTRY main { indices_a = s32[5,7] parameter(1) indices_b = s32[2] parameter(2) gather_a = s32[5,3,7] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3,1} + slice_sizes={3,1} ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b), - output_window_dims={0,1}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={0,1}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=1, - window_bounds={5,3,1} + slice_sizes={5,3,1} } )"; @@ -256,17 +256,17 @@ ENTRY main { indices_a = s32[2] parameter(1) indices_b = s32[5,7] parameter(2) gather_a = s32[2,6] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,6} + slice_sizes={1,6} } )"; @@ -284,17 +284,17 @@ ENTRY main { indices_a = s32[5,7] parameter(1) indices_b = s32[4,8] parameter(2) gather_a = s32[5,3,7] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3,1} + slice_sizes={3,1} ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), - output_window_dims={1,2}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={1,2}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=2, - window_bounds={5,3,1} + slice_sizes={5,3,1} } )"; @@ -312,11 +312,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5] parameter(0) gather = s32[5,4] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2] reshape(gather) } )"; @@ -333,11 +333,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,7] parameter(0) gather = s32[5,4,7] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2,7] reshape(gather) } )"; @@ -358,11 +358,11 @@ ENTRY main { {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[5,7] parameter(0) gather = s32[5,2,6,7] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1,2}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,2,6} + slice_sizes={1,2,6} ROOT reshape = s32[5,3,4,7] reshape(gather) } )"; @@ -381,11 +381,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,6] reshape(gather) } )"; @@ -408,14 +408,14 @@ ENTRY main { operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) i.0 = s64[1,3]{1,0} parameter(0) - g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2}, - elided_window_dims={0}, gather_dims_to_operand_dims={0}, - index_vector_dim=2, window_bounds={1,3} + g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, + index_vector_dim=2, slice_sizes={1,3} i.1 = s64[1] parameter(1) - g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2}, - elided_window_dims={1}, gather_dims_to_operand_dims={1}, - index_vector_dim=1, window_bounds={1,1,3} + g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), offset_dims={0,2}, + collapsed_slice_dims={1}, start_index_map={1}, + index_vector_dim=1, slice_sizes={1,1,3} ROOT reshape = s32[1,3]{1,0} reshape(g.1) } @@ -441,11 +441,11 @@ ENTRY main { operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,6] reshape(gather) } )"; @@ -469,11 +469,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[1] parameter(0) gather = s32[1,1,6] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1,2}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={1,1,6} + slice_sizes={1,1,6} ROOT reshape = s32[1,1,1,6] reshape(gather) } )"; @@ -500,11 +500,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1,5] parameter(0) gather = s32[1,5,6] gather(operand, indices), - output_window_dims={2}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={2}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,5,6] reshape(gather) } )"; @@ -530,11 +530,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2,2,3] reshape(gather) } )"; @@ -562,11 +562,11 @@ ENTRY main { {{1,2},{3,4},{5,6},{7,8},{9,10}}}) indices = s32[7] parameter(0) gather = s32[3,2,7] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0,1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1,2} + slice_sizes={3,1,2} ROOT reshape = s32[6,7] reshape(gather) } )"; @@ -594,11 +594,11 @@ ENTRY main { {{1},{2},{3},{4}}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6,1] gather(operand, indices), - output_window_dims={1,3}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1,3}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4,1} + slice_sizes={1,4,1} ROOT reshape = s32[5,2,2,2,3,1] reshape(gather) } )"; @@ -623,20 +623,20 @@ ENTRY main { operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) indices = s32[5] parameter(0) gather = f32[5,4] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT tanh = f32[5,4] tanh(gather) } )"; AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( (scalar-indexed-const (constant f32[3,4] f32[3,4] { - { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, - { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, - { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } + { 0.761594, 0.964028, 0.995055, 0.999329 }, + { 0.761594, 0.995055, 0.964028, 0.999329 }, + { 0.999329, 0.995055, 0.964028, 0.761594 } }) %indices 0->[0]))"); } @@ -650,11 +650,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -678,11 +678,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT sub = s32[5,4] subtract(gather, constant_broadcasted) } )"; @@ -706,11 +706,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT sub = s32[5,4] subtract(constant_broadcasted, gather) } )"; @@ -733,11 +733,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -760,11 +760,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -808,11 +808,11 @@ ENTRY main { dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_lhs = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; @@ -835,11 +835,11 @@ ENTRY main { dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) indices = s32[5] parameter(0) dot_lhs = s32[3,5] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0} } )"; @@ -863,11 +863,11 @@ ENTRY main { dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[3,5] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; @@ -892,11 +892,11 @@ ENTRY main { dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[5,3] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; @@ -921,11 +921,11 @@ ENTRY main { dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) indices = s32[4] parameter(0) dot_rhs = s32[2,3,4] gather(gather_operand, indices), - output_window_dims={0,1}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={0,1}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=1, - window_bounds={2,3,1} + slice_sizes={2,3,1} ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} @@ -952,11 +952,11 @@ ENTRY main { dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) indices = s32[2] parameter(0) dot_lhs = s32[3,2] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc index 5c193fceb984448cf0532d7e1010281268614293..5fd779ebf9b59e34a0844cc3a898bb72ce6044ee 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/inliner.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h index a523811f6c141a7dc24b1c88897d82d046aa1a2d..e20af08fb7329c3646903761ee081e421daa5712 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/inliner.h @@ -24,10 +24,10 @@ namespace xla { // A pass which performs inlining. Which can result, for example, in functions // that were previously being mapped by Map instead directly applied to the // forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)). -class Inliner : public HloPassInterface { +class Inliner : public HloModulePass { public: ~Inliner() override = default; - tensorflow::StringPiece name() const override { return "inline"; } + absl::string_view name() const override { return "inline"; } // Run inlining on the given computation. Returns whether the computation was // changed. diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 32937b33b3737482f07d4c7607f7f1c5c183a56b..7e967f035c1054e22d10790188a5a232ca8e751a 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using InlinerTest = HloTestBase; +using InlinerTest = HloVerifiedTestBase; // Test that `map` with `max` is transformed to `max` TEST_F(InlinerTest, MapMax) { @@ -64,14 +64,14 @@ TEST_F(InlinerTest, MapMax) { hlo_module->AddEntryComputation(std::move(computation)); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } // Test that `constant` function is changed to `broadcast`. @@ -98,14 +98,14 @@ TEST_F(InlinerTest, MapConstant) { hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -136,14 +136,14 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { hlo_module->AddEntryComputation(std::move(computation)); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index e2191aedb7f03ad4a956d9f4b8b1bfd4f5b5e08e..3fdc2cee9aad0fe70f66920f757ee5c52bba711f 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" @@ -120,6 +122,8 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: @@ -129,7 +133,6 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: - case HloOpcode::kHostCompute: case HloOpcode::kLog: case HloOpcode::kLog1p: case HloOpcode::kMap: @@ -170,7 +173,8 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { }); return std::count_if(hlo->operands().begin(), hlo->operands().end(), [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast) { + if (operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kIota) { return false; } if (operand->opcode() == HloOpcode::kConstant && @@ -188,13 +192,13 @@ bool InstructionFusion::CanFuseOnAllPaths( if (consumer == producer) { return true; } - if (!consumer->IsFusable()) { + if (!consumer->IsFusible()) { return false; } for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter - // whether it's fusable. + // whether it's fusible. if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } @@ -204,7 +208,7 @@ bool InstructionFusion::CanFuseOnAllPaths( } // The producer is reachable from consumer_operand which means we need // to be able to fuse consumer_operand into consumer in order for - // producer to be fusable into consumer on all paths. + // producer to be fusible into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { @@ -215,8 +219,8 @@ bool InstructionFusion::CanFuseOnAllPaths( } InstructionFusion::HloInstructionSet -InstructionFusion::ComputeGloballyUnfusable( - tensorflow::gtl::ArraySlice post_order) { +InstructionFusion::ComputeGloballyUnfusible( + absl::Span post_order) { // Forbid fusion of producers that: // a) Need to be duplicated, unless they can be fused into all consumers // via all paths. @@ -269,19 +273,19 @@ InstructionFusion::ComputeGloballyUnfusable( // all of its consumers on all paths. // // That means, that for: - // A --> B (fusable) - // \-> C (non-fusable) + // A --> B (fusible) + // \-> C (non-fusible) // A will be not allowed to be fused into B, as it cannot be fused into C. // // Similarly, for: // A -------------> B // \-> C -> D -/ // If: - // - A is fusable into B and C, and D is fusable into B - // - C is *not* fusable into D + // - A is fusible into B and C, and D is fusible into B + // - C is *not* fusible into D // A will be not allowed to be fused into B, as it cannot be fused via // all paths. - if (producer->IsFusable() && + if (producer->IsFusible() && CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { continue; } @@ -292,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusable( return do_not_duplicate; } +namespace { + +// A FusionQueue that uses reverse post order. +// +// We want to be able to remove arbitrary instructions from the post order and +// also compare positions of instructions in the post order. To make this +// possible, create vector of instructions in post order and create a map from +// HloInstruction* to the instruction's index in the vector. An instruction is +// "removed" from the vector by setting it's element to nullptr. +class ReversePostOrderFusionQueue : public FusionQueue { + public: + explicit ReversePostOrderFusionQueue(HloComputation* computation) { + post_order_ = computation->MakeInstructionPostOrder(); + + for (size_t i = 0; i < post_order_.size(); ++i) { + InsertOrDie(&post_order_index_, post_order_[i], i); + } + } + + std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() override { + // Instructions are "removed" from the post order by nulling out the element + // in the vector, so if the pointer is null, continue to the next + // instruction in the sort. + while (!post_order_.empty() && post_order_.back() == nullptr) { + post_order_.pop_back(); + } + if (post_order_.empty()) { + return std::pair>{nullptr, {}}; + } + // We want to iterate in reverse post order, so remove from the back of the + // vector. + HloInstruction* instruction = post_order_.back(); + post_order_.pop_back(); + + CHECK(instruction != nullptr); + // Remove instruction from the index map to ensure the vector and map stay + // consistent. + post_order_index_.erase(instruction); + + // Consider each operand of this instruction for fusion into this + // instruction. We want to consider the operands in a particular order to + // avoid creating duplicate instruction clones in the fusion instruction. + // For example, consider the following expression: + // + // A = ... + // B = op(A) + // C = op(A, B) + // + // If we are considering the operands of C for fusion into C. We might + // fuse A or B first. If we fuse A first, we get: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // C' = op(A', B) } + // + // Where A' and C' are clones of A and C, respectively. Now only B is an + // operand of the fusion instruction C_fusion, so then we fuse B: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // B' = op(A) + // C' = op(A', B') } + // + // Now A is an operand of C_fusion again, so we then fuse A (again!): + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // A" = .. + // B' = op(A") + // C' = op(A', B') } + // + // We prevent this duplication by considering the operands in the order + // they appear int the queue. In the example, this ensures that B will be + // considered before A. + // + // We store the original indices of the operands to pass to ShouldFuse. + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the queue. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) { + continue; + } + sorted_operand_numbers.push_back(i); + } + std::sort( + sorted_operand_numbers.begin(), sorted_operand_numbers.end(), + [&](int64 i, int64 j) { + // Instructions with higher priority in the queue come first. + return ( + FindOrDie(post_order_index_, instruction->mutable_operand(i)) > + FindOrDie(post_order_index_, instruction->mutable_operand(j))); + }); + return std::make_pair(instruction, sorted_operand_numbers); + } + + void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) override { + // Fusing an instruction into a fusion instruction can change the operand + // set of the fusion instruction. For simplicity just re-enqueue the + // instruction and reconsider it for further fusion in the next iteration. + InsertOrDie(&post_order_index_, fusion, post_order_.size()); + post_order_.push_back(fusion); + } + + void RemoveInstruction(HloInstruction* instruction) override { + post_order_[FindOrDie(post_order_index_, instruction)] = nullptr; + post_order_index_.erase(instruction); + } + + private: + std::vector post_order_; + tensorflow::gtl::FlatMap post_order_index_; +}; + +} // namespace + +std::unique_ptr InstructionFusion::GetFusionQueue( + HloComputation* computation, + const std::function& skip_producer) { + return absl::make_unique(computation); +} + StatusOr InstructionFusion::Run(HloModule* module) { VLOG(2) << "Before instruction fusion:"; XLA_VLOG_LINES(2, module->ToString()); @@ -303,116 +439,36 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = computation_->ComputeReachability(); - // We want to be able to remove arbitrary instructions from the post order - // and also compare positions of instructions in the post order. To make - // this possible, create vector of instructions in post order and create a - // map from HloInstruction* to the instruction's index in the vector. An - // instruction is "removed" from the vector by setting it's element to - // nullptr. - std::vector post_order = - computation_->MakeInstructionPostOrder(); - - tensorflow::gtl::FlatMap post_order_index; - for (size_t i = 0; i < post_order.size(); ++i) { - InsertOrDie(&post_order_index, post_order[i], i); - } - - HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); + HloInstructionSet do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + auto fusion_queue = + GetFusionQueue(computation_, [&](HloInstruction* producer) { + return do_not_duplicate.count(producer) > 0; + }); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all // edges. When we fuse an edge, we create a copy of the producer inside the // fusion instruction. - while (!post_order.empty()) { - // We want to iterate in reverse post order, so remove from the back of - // the vector. - HloInstruction* instruction = post_order.back(); - post_order.pop_back(); - - // Instructions are "removed" from the post order by nulling out the - // element in the vector, so if the pointer is null, continue to the next - // instruction in the sort. + while (true) { + auto next_entry = + fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder(); + auto instruction = next_entry.first; if (instruction == nullptr) { - continue; + break; } - // Remove instruction from the index map to ensure the vector and map stay - // consistent. - post_order_index.erase(instruction); - - if (!instruction->IsFusable() && + if (!instruction->IsFusible() && instruction->opcode() != HloOpcode::kFusion) { continue; } - // Consider each operand of this instruction for fusion into this - // instruction. We want to consider the operands in a particular order to - // avoid creating duplicate instruction clones in the fusion instruction. - // For example, consider the following expression: - // - // A = ... - // B = op(A) - // C = op(A, B) - // - // If we are considering the operands of C for fusion into C. We might - // fuse A or B first. If we fuse A first, we get: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // C' = op(A', B) } - // - // Where A' and C' are clones of A and C, respectively. Now only B is an - // operand of the fusion instruction C_fusion, so then we fuse B: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // B' = op(A) - // C' = op(A', B') } - // - // Now A is an operand of C_fusion again, so we then fuse A (again!): - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // A" = .. - // B' = op(A") - // C' = op(A', B') } - // - // We prevent this duplication by considering the operands in the reverse - // order they appear in the instruction post order. In the example, this - // ensures that B will be considered before A. - // - // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers; - sorted_operand_numbers.reserve(instruction->operands().size()); - for (int i = 0; i < instruction->operands().size(); ++i) { - // This will happen if we have two possible instructions to fuse the - // same operand into; once the operand is fused into one instruction, - // the other instruction will get a new get-tuple-element as its - // operand, which is not in the post-order index. - // TODO(tjoerg): Look into fusing past these multi-output fuse points. - if (post_order_index.find(instruction->mutable_operand(i)) == - post_order_index.end()) { - continue; - } - sorted_operand_numbers.push_back(i); - } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 j) { - // Instructions with higher indices in the post order come - // first. - return ( - FindOrDie(post_order_index, instruction->mutable_operand(i)) > - FindOrDie(post_order_index, instruction->mutable_operand(j))); - }); + std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); - if (!operand->IsFusable()) { + if (!operand->IsFusible()) { continue; } @@ -422,32 +478,31 @@ StatusOr InstructionFusion::Run(HloModule* module) { // TODO(tjoerg): Consider making multi-output fusion the default. if (ShouldFuse(instruction, i) && do_not_duplicate.count(operand) == 0) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = Fuse(operand, instruction); } else if (ShouldFuseIntoMultiOutput(instruction, i) && !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = FuseIntoMultiOutput(operand, instruction); } else { continue; } - // Fusing an instruction into a fusion instruction can change the - // operand set of the fusion instruction. For simplicity just push the - // instruction to the top of the post_order and reconsider it for - // further fusion in the next iteration of the outer loop. - post_order.push_back(fusion_instruction); - InsertOrDie(&post_order_index, fusion_instruction, - post_order.size() - 1); + fusion_queue->OnFusingInstruction(fusion_instruction, operand, + instruction); changed = true; if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting its - // location to nullptr. - post_order[FindOrDie(post_order_index, operand)] = nullptr; - post_order_index.erase(operand); - + do_not_duplicate.erase(operand); + // Operand is now dead. Remove from queue. + fusion_queue->RemoveInstruction(operand); // Remove from computation. TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); } + + if (fusion_instruction != instruction) { + do_not_duplicate.erase(instruction); + } break; } } @@ -496,7 +551,7 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - return c_any_of( + return absl::c_any_of( consumer->operands(), [&](const HloInstruction* consumer_operand) { // The fusion algorithm traverses the HLO graph in reverse post order. // Thus `cosumers` is visited before its operands (including diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f73ca9adf768ed26f9ec9f162e01b7b160f50daf..7e1196fb7fbeb4072929773b2161fe28233d73d9 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -24,19 +24,46 @@ limitations under the License. namespace xla { +// A queue interface that allows implementations to choose fusion candidates in +// custom order. +class FusionQueue { + public: + FusionQueue() = default; + virtual ~FusionQueue() = default; + + // Dequeues the next fusion candidates: a consumer and the list of producers + // as operand indices. + virtual std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() = 0; + + // A callback passed to the queue implementation right before the producer is + // fused into the consumer. + virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} + + // A callback passed to the queue implementation right after the fusion is + // created. Note that original_producer could have been destroyed. + virtual void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) {} + + // A callback passed to the queue implementation to notify the removal of an + // instruction. + virtual void RemoveInstruction(HloInstruction* instruction) = 0; +}; + // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in // code generation. Derived classes define ShouldFuse method to select which // instructions to fuse. -class InstructionFusion : public HloPassInterface { +class InstructionFusion : public HloModulePass { public: explicit InstructionFusion( std::function is_expensive, bool may_duplicate = true) : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {} ~InstructionFusion() override = default; - tensorflow::StringPiece name() const override { return "fusion"; } + absl::string_view name() const override { return "fusion"; } // Run instruction fusion on the given computation. Returns whether the // computation was changed (instructions were fused). @@ -48,6 +75,13 @@ class InstructionFusion : public HloPassInterface { static bool IsExpensive(const HloInstruction& instruction); protected: + // Returns a FusionQueue that implements custom order of instructions being + // fused. The default implementation processes consumers in reverse post + // order. + virtual std::unique_ptr GetFusionQueue( + HloComputation* computation, + const std::function& skip_producer); + // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. // Derived classes should define this method to specify which instructions @@ -122,8 +156,8 @@ class InstructionFusion : public HloPassInterface { // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. - HloInstructionSet ComputeGloballyUnfusable( - tensorflow::gtl::ArraySlice post_order); + HloInstructionSet ComputeGloballyUnfusible( + absl::Span post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 9e7a15f0330d3f06779c850a4b575f84fe0b9505..da1ad90959dc0ab1a840b3390281ce9d4999651e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -158,7 +158,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) { HloComputation::Builder builder(TestName()); auto shape = ShapeUtil::MakeShape(F32, {16, 16}); auto param0 = @@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { // Make sure we do not duplicate the add, as we cannot fuse through the rng. // // p0 -> add -------------------------> sub @@ -309,7 +309,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); // A variant of the above that allows the algorithm to put add2 into the set - // of unfusable ops to short-circuit the decision whether add1 should be fused + // of unfusible ops to short-circuit the decision whether add1 should be fused // into sub2. // // /---------------\ diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 8652599dc6d48ff8c2aaa703fead161f891a57d1..146c9052f10cca8b199a480491d9a672d8bebdff 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -12,12 +12,11 @@ cc_library( srcs = ["interpreter_transfer_manager.cc"], hdrs = ["interpreter_transfer_manager.h"], deps = [ - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -32,8 +31,6 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", @@ -54,6 +51,7 @@ cc_library( "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains compiler registration ) @@ -79,7 +77,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", @@ -91,6 +88,8 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -116,5 +115,6 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 9f8f4bda875cdff5e20fa8ca8eeecaa1140e2b9c..bb69cb9c47ff2c7de8d13832c4b8e6216c62da73 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -69,8 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. std::unique_ptr executable = - xla::MakeUnique(std::move(hlo_module), - xla::MakeUnique()); + absl::make_unique( + std::move(hlo_module), absl::make_unique()); return std::move(executable); } @@ -103,11 +103,11 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() static bool InitModule() { xla::Compiler::RegisterCompilerFactory( se::interpreter::kXlaInterpreterPlatformId, []() { - return xla::MakeUnique(); + return absl::make_unique(); }); xla::ComputationPlacer::RegisterComputationPlacer( se::interpreter::kXlaInterpreterPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 8d40c08d555a232b7cf3b81cc0f9970804c2f896..a06d6113e84630df14ff68280c248cccb9afaf06 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" @@ -47,7 +47,7 @@ InterpreterExecutable::~InterpreterExecutable() {} StatusOr InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); se::StreamExecutor* executor = stream->parent(); @@ -73,30 +73,29 @@ StatusOr InterpreterExecutable::ExecuteOnStream( // Transform the ShapedBuffer arguments into literals which the evaluator // consumes. - std::vector> arg_literals; + std::vector arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN(std::unique_ptr arg_literal, + TF_ASSIGN_OR_RETURN(Literal arg_literal, transfer_manager->TransferLiteralFromDevice( run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. - std::unique_ptr result_literal; + Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); - TF_ASSIGN_OR_RETURN(result_literal, - evaluator_->Evaluate>( - *computation, arg_literals)); + TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( + *computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, transfer_manager->AllocateScopedShapedBuffer( - result_literal->shape(), run_options->allocator(), + result_literal.shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - run_options->stream(), *result_literal, result)); + run_options->stream(), result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -111,7 +110,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( StatusOr InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return tensorflow::errors::Unimplemented( "ExecuteAsyncOnStream is not yet supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 91d8148d26dc8eddbafdaf4870d9efbb73a12816..3b1ebce0c75457d65e6834c809fe488a9c4a159a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -48,13 +48,13 @@ class InterpreterExecutable : public Executable { StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override LOCKS_EXCLUDED(evaluator_lock_); StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; static int64 ShapeSizeBytes(const Shape& shape); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 9b109022fbfc698f7dadc678ef837da270a5e74a..fbb99457847dca69a1901006d5d8ff713882f918 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -22,9 +22,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/device_description.h" @@ -47,7 +47,7 @@ limitations under the License. namespace stream_executor { namespace interpreter { -using Args = tensorflow::gtl::ArraySlice; +using Args = absl::Span; class XlaInterpreterExecutor : public internal::StreamExecutorInterface { public: @@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } // No "synchronize all activity" implemented for this platform at the moment. - bool SynchronizeAllActivity() override { return false; } + bool SynchronizeAllActivity() override { return true; } bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override { return false; } diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc index d27cd7502f10a1f615fc5b0d610acafdf55e3e43..7955ee5cf37f3fa45b942d8ab05a60076857dc6c 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -31,7 +31,7 @@ InterpreterTransferManager::InterpreterTransferManager() static std::unique_ptr CreateInterpreterTransferManager() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h index 2b44f308218e2f61f08012769246b8a0e9639822..b732230fdd88b694f21ad5bc03d373331f8fb8f9 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_ #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/core/platform/macros.h" @@ -33,4 +33,4 @@ class InterpreterTransferManager : public GenericTransferManager { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 42c2c28997d5f3b02f1fe4effca164c893e4071d..c9b40d3c6195f80a19272a0d98890049d02315b9 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -17,13 +17,14 @@ limitations under the License. #include +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status_macros.h" -#include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" @@ -70,15 +71,15 @@ port::StatusOr XlaInterpreterPlatform::GetExecutor( port::StatusOr> XlaInterpreterPlatform::GetUncachedExecutor( const StreamExecutorConfig& config) { - auto executor = MakeUnique( - this, MakeUnique(config.plugin_config)); + auto executor = absl::make_unique( + this, absl::make_unique(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ port::error::INTERNAL, - port::Printf( + absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())}; + config.ordinal, init_status.ToString())}; } return std::move(executor); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index b5a9d6e8e7d66ae0c560226a79578d85eaf55644..082bf8bffed484244139e79f4d3fe30ca091d8ac 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -26,9 +26,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -48,21 +52,11 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { -// For now moving only one API here, but we should have a single top level -// anonymous namespace, instead of three or four spread all over this file. -namespace { - -} // namespace - std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint) { out << constraint.ToString(); @@ -77,9 +71,8 @@ BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, } string BufferLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s", - buffer_->ToString().c_str(), - LayoutUtil::HumanString(layout_).c_str()); + return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(), + LayoutUtil::HumanString(layout_)); } OperandLayoutConstraint::OperandLayoutConstraint( @@ -98,15 +91,14 @@ OperandLayoutConstraint::OperandLayoutConstraint( } string OperandLayoutConstraint::ToString() const { - return tensorflow::strings::Printf( - "OperandLayoutConstraint %s, operand %lld: %s", - instruction_->name().c_str(), operand_no_, - shape_layout_.ToString().c_str()); + return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s", + instruction_->name(), operand_no_, + shape_layout_.ToString()); } string ResultLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("ResultLayoutConstraint: %s", - shape_layout_.ToString().c_str()); + return absl::StrFormat("ResultLayoutConstraint: %s", + shape_layout_.ToString()); } LayoutConstraints::LayoutConstraints( @@ -137,7 +129,7 @@ PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( } auto& buffer_set = buffer_sets_cache_ - .emplace(instruction, MakeUnique()) + .emplace(instruction, absl::make_unique()) .first->second; const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); points_to_set.ForEachElement( @@ -174,8 +166,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Layout of buffer %s cannot be constrained because buffer is not " "array-shaped, has shape: %s", - buffer.ToString().c_str(), - ShapeUtil::HumanString(buffer.shape()).c_str()); + buffer.ToString(), ShapeUtil::HumanString(buffer.shape())); } TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); @@ -191,9 +182,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", - buffer.ToString().c_str(), - LayoutUtil::HumanString(curr_constraint.layout()).c_str(), - LayoutUtil::HumanString(layout).c_str()); + buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()), + LayoutUtil::HumanString(layout)); } iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } else { @@ -227,11 +217,11 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, } if (curr_shape_layout->mandatory()) { return FailedPrecondition( - "Operand %lld of instruction %s already has a layout constraint " + "Operand %d of instruction %s already has a layout constraint " "%s, cannot add incompatible constraint %s", - operand_no, instruction->name().c_str(), - curr_shape_layout->shape_layout().ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + operand_no, instruction->name(), + curr_shape_layout->shape_layout().ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } } @@ -240,9 +230,9 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, // layouts beyond this immediate use and is complicated to handle. if (OperandBufferForwarded(instruction, operand_no)) { return FailedPrecondition( - "Cannot constraint layout of operand %lld of instruction %s " + "Cannot constraint layout of operand %d of instruction %s " "because instruction forwards operand's LogicalBuffer(s)", - operand_no, instruction->name().c_str()); + operand_no, instruction->name()); } auto key = std::make_pair(instruction, operand_no); @@ -284,8 +274,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, return FailedPrecondition( "Result of computation %s already has the layout constraint %s, " "cannot add incompatible constraint %s", - computation_->name().c_str(), curr_shape_layout->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + computation_->name(), curr_shape_layout->ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // New constraint matches existing constraint. Nothing to do. return Status::OK(); @@ -307,9 +297,8 @@ Status LayoutConstraints::SetInstructionLayout( if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) { return FailedPrecondition( "Instruction %s of shape %s cannot be assigned incompatible layout %s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // Create a BufferLayoutConstraint for each array shape in the output of the @@ -368,31 +357,27 @@ const ShapeLayout* LayoutConstraints::ResultLayout() const { string LayoutConstraints::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ", - computation_->name(), ":\n"); + absl::StrAppend(&output, "LayoutConstraints for computation ", + computation_->name(), ":\n"); for (auto* instruction : computation_->MakeInstructionPostOrder()) { - tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(), - "\n"); + absl::StrAppend(&output, " ", instruction->ToShortString(), "\n"); for (int64 i = 0; i < instruction->operand_count(); ++i) { if (OperandLayout(instruction, i) != nullptr) { - tensorflow::strings::StrAppend( - &output, " operand (", i, - "): ", OperandLayout(instruction, i)->ToString(), "\n"); + absl::StrAppend(&output, " operand (", i, + "): ", OperandLayout(instruction, i)->ToString(), "\n"); } } for (const LogicalBuffer* buffer : points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { if (BufferLayout(*buffer) != nullptr) { - tensorflow::strings::StrAppend( - &output, " ", buffer->ToString(), " : ", - LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); + absl::StrAppend(&output, " ", buffer->ToString(), " : ", + LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); } } } if (ResultLayout() != nullptr) { - tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(), - "\n"); + absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n"); } return output; } @@ -763,7 +748,7 @@ Status CheckParameterLayout(HloInstruction* parameter, return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", - parameter->ToString().c_str(), parameter_layout.ToString().c_str()); + parameter->ToString(), parameter_layout.ToString()); } return Status::OK(); } @@ -774,8 +759,8 @@ Status CheckConstantLayout(HloInstruction* constant) { constant->shape())) { return InternalError( "constant instruction %s does not match the layout of its literal %s", - constant->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str()); + constant->ToString(), + ShapeUtil::HumanStringWithLayout(constant->literal().shape())); } return Status::OK(); } @@ -870,8 +855,7 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, ? instruction.sharding().GetSubSharding(instruction.shape(), index) : instruction.sharding(); // We propagate the sharding to the copied instruction only if it is a - // special sharding, like tiled ones, or special devices like the - // HostCompute module. + // special sharding, like tiled ones. // Otherwise it is preferable to leave the new instruction without device, // and let the automatic device placer to choose the best location. auto device = sharding.UniqueDevice(); @@ -908,13 +892,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str(), - buffer->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction_subshape) - .c_str(), - ShapeUtil::HumanStringWithLayout(buffer->shape()) - .c_str()); + instruction->name(), absl::StrJoin(index, ","), + buffer->ToString(), + ShapeUtil::HumanStringWithLayout(instruction_subshape), + ShapeUtil::HumanStringWithLayout(buffer->shape())); } } } @@ -998,17 +979,18 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape())); CHECK(ShapeUtil::IsArray(operand->shape())); - if (instruction->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && + if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == - ShapeUtil::Rank(instruction->shape())) { - // Assign operands the same layout as the instruction, so that + ShapeUtil::Rank(instruction->shape()) && + InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) { + // Propagate the result layout to the operand layout if the instruction + // requires the same layout out for the result and the operand. + // + // For elementwise operations, using the same layout for the operands and + // the result also has the following benefits: // 1) the elementwise operation can reuse its operand's buffer, and // 2) the input and output elements can reuse the same linear index. - // - // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit - // from assigning the same layout to input and output. - return MakeUnique(output_layout); + return absl::make_unique(output_layout); } if (instruction->opcode() == HloOpcode::kReshape) { @@ -1031,13 +1013,13 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique(operand_shape.layout()); + return absl::make_unique(operand_shape.layout()); } if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { *operand_shape.mutable_layout() = output_layout; if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique(output_layout); + return absl::make_unique(output_layout); } } auto aligned_operand_shape = @@ -1046,7 +1028,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( auto operand_layout = aligned_operand_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape)); - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } } @@ -1062,7 +1044,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } return nullptr; @@ -1076,11 +1058,11 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( CHECK(ShapeUtil::IsArray(user->shape()) && ShapeUtil::IsArray(operand->shape())); - if (user->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) { + if (!ShapeUtil::IsScalar(operand->shape()) && + ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && + InstructionRequiresInputLayoutEqualToOutputLayout(user)) { // Assign users the same layout as the operand. - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } if (user->opcode() == HloOpcode::kReshape) { @@ -1103,13 +1085,13 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique(output_shape.layout()); + return absl::make_unique(output_shape.layout()); } if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { *output_shape.mutable_layout() = operand_layout; if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } } auto aligned_user_shape = @@ -1118,7 +1100,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( auto user_layout = aligned_user_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(user_layout, output_shape)); - return MakeUnique(user_layout); + return absl::make_unique(user_layout); } } @@ -1134,7 +1116,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); - return MakeUnique(user_layout); + return absl::make_unique(user_layout); } return nullptr; @@ -1385,7 +1367,7 @@ StatusOr InferArrayLayout( // This should not happen because we've assigned layouts to all // instructions preceding this one. return InternalError("LogicalBuffer %s does not have a layout", - source_buffer->ToString().c_str()); + source_buffer->ToString()); } if (first_buffer_layout == nullptr) { @@ -1400,9 +1382,8 @@ StatusOr InferArrayLayout( return FailedPrecondition( "Array at index {%s} in instruction %s aliases buffers %s " "and %s which have different layouts", - tensorflow::str_util::Join(index, ",").c_str(), - instruction->name().c_str(), source_buffers[0]->ToString().c_str(), - source_buffer->ToString().c_str()); + absl::StrJoin(index, ","), instruction->name(), + source_buffers[0]->ToString(), source_buffer->ToString()); } } @@ -1563,14 +1544,14 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // and the computation result. The latter two are specified in // computation_layout, so we only need to keep the existing layouts for // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidently use the existing layout. + // layout assignment pass that may accidentally use the existing layout. for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction // present in the IR before layout assignment is a bug. return InternalError( "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + instruction->ToString()); } if (instruction->opcode() != HloOpcode::kInfeed) { LayoutUtil::ClearLayout(instruction->mutable_shape()); @@ -1822,6 +1803,107 @@ StatusOr LayoutAssignment::Run(HloModule* module) { return true; } +bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kAnd: + case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kClz: + case HloOpcode::kComplex: + case HloOpcode::kConcatenate: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCos: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kCustomCall: + case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kEq: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFft: + case HloOpcode::kFloor: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLe: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kLt: + case HloOpcode::kMap: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kPad: + case HloOpcode::kPower: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kReduceWindow: + case HloOpcode::kRemainder: + case HloOpcode::kReverse: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSelect: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return true; + case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kConstant: + case HloOpcode::kConvolution: + case HloOpcode::kCopy: + case HloOpcode::kDomain: + case HloOpcode::kDot: + case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kIota: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReduce: + case HloOpcode::kReshape: + case HloOpcode::kRng: + case HloOpcode::kScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kAfterAll: + case HloOpcode::kTrace: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return false; + } +} + Status LayoutAssignment::Init() { computation_layouts_.clear(); *entry_computation_layout_ = saved_entry_computation_layout_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index f9e8dbea2f8aa224318adf3cf4b5e493792d3093..e29c199c42a4878daaf2eeb86b6909d6d3ff920e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -281,7 +281,7 @@ class ChannelLayoutConstraints { // HLO pass which assigns layouts to all instructions in the HLO module while // satisfying all necessary invariants and minimizing cost. -class LayoutAssignment : public HloPassInterface { +class LayoutAssignment : public HloModulePass { public: // entry_computation_layout is modified to populate a layout for the result in // the case that no particular layout is requested. @@ -297,12 +297,17 @@ class LayoutAssignment : public HloPassInterface { ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} - tensorflow::StringPiece name() const override { return "layout-assignment"; } + absl::string_view name() const override { return "layout-assignment"; } // Assign layouts to the given module. Returns whether the module was changed // (any layouts were changed). StatusOr Run(HloModule* module) override; + // Returns true if the instruction requires that operands with the same rank + // as the output have to have the same layout as the output. + virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction); + protected: // These methods, invoked by PropagateConstraints, propagate a layout // constraint to its neighbors (i.e. operands and users) in order to minimize diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index a16fa75e3032cfa4257d9b5608dd176fdb4ddbdb..752a61476dd7892a2b7f531c4057015f48fc4758 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -34,13 +35,12 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -49,7 +49,7 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloTestBase { +class LayoutAssignmentTest : public HloVerifiedTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout, @@ -59,7 +59,7 @@ class LayoutAssignmentTest : public HloTestBase { EXPECT_IS_OK(layout_assignment.Run(module).status()); } - std::vector LayoutOf(HloModule* module, tensorflow::StringPiece name) { + std::vector LayoutOf(HloModule* module, absl::string_view name) { auto minor_to_major = FindInstruction(module, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); @@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { *computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); auto constant_literal2 = LiteralUtil::CreateR2WithLayout( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); - Shape ashape = constant_literal1->shape(); + Shape ashape = constant_literal1.shape(); auto constant1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(constant_literal1))); @@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE( AlgebraicSimplifier(/*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return false; }) - .Run(module.get()) + .Run(module) .ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. @@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, f32_4, "param")); auto broadcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_34, param, {3})); + HloInstruction::CreateBroadcast(f32_34, param, {1})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast)); auto broadcast2 = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_234, tanh, {2})); + HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); auto module = CreateNewModule(); @@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module.get()).status()); + EXPECT_IS_OK(layout_assignment.Run(module).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); - module = + std::unique_ptr compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); EXPECT_EQ(Status::OK(), backend() .compiler() - ->RunBackend(std::move(module), + ->RunBackend(std::move(compiled_module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .status()); @@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(&module(), &computation_layout); - EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(0) .layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module.get()).status(); + Status error_status = layout_assignment.Run(module).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( error_status.error_message(), @@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -851,14 +851,151 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(module.get(), &computation_layout, &channel_constraints); + AssignLayouts(&module(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); - EXPECT_TRUE( - ShapeUtil::Equal(ShapeUtil::GetSubshape( - FindInstruction(module.get(), "send")->shape(), {0}), - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); + EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); +} + +TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopySliceOperandToAvoidImplicitLayoutChange + + ENTRY CopySliceOperandToAvoidImplicitLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]} + ROOT add0 = f32[3,4]{1,0} add(par0,slice0) + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + EXPECT_THAT(root, op::Add(op::Parameter(), + op::Slice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy))))); +} + +TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopyDSliceOperandToAvoidImplicitLayoutChange + + ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + par2 = s32[2] parameter(2) + dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4} + ROOT add0 = f32[3,4]{1,0} add(par0,dslice0) + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); +} + +TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopyConcatOperandToAvoidImplicitLayoutChange + + ENTRY CopyConcatOperandToAvoidImplicitLayoutChange { + par0 = f32[3,8]{1,0} parameter(0) + par1 = f32[3,5]{0,1} parameter(1) + par2 = f32[3,3]{1,0} parameter(2) + concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2), + dimensions={1} + ROOT add0 = f32[3,8]{1,0} add(par0,concat0) + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + op::Concatenate(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); +} + +TEST_F(LayoutAssignmentTest, + ConvolutionOperandWithImplicitLayoutChangeNotCopied) { + const char* module_str = R"( + HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied + + ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied { + par0 = f32[128,3,230,230]{2,3,1,0} parameter(0) + par1 = f32[7,7,3,64]{3,2,0,1} parameter(1) + ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1), + window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01, + feature_group_count=1 + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1))); +} + +TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { + const char* module_str = R"( + HloModule PropagatingLayoutFromResultToOperand + + ENTRY PropagatingLayoutFromResultToOperand { + par0 = f32[4,5]{1,0} parameter(0) + ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]} + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1}); + EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)), + op::ShapeWithLayout(shape_copy)))); } } // namespace diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index cdd3daf73b8ac1a4d1ec3c81224c2c0bfe8e5811..540bbb7c7a74f65ab70f4c6704d6600db2adbb60 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -38,6 +38,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -69,6 +70,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", "@llvm//:target", @@ -88,6 +91,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -103,6 +109,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -120,6 +128,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -133,9 +142,7 @@ cc_library( ":llvm_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "@llvm//:core", @@ -159,6 +166,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -193,7 +201,10 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@llvm//:core", + "@llvm//:support", ], ) @@ -208,6 +219,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -219,7 +231,7 @@ cc_library( deps = [ ":llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -230,6 +242,7 @@ cc_library( hdrs = ["buffer_assignment_util.h"], deps = [ "//tensorflow/compiler/xla/service:buffer_assignment", + "@com_google_absl//absl/strings", ], ) @@ -242,3 +255,12 @@ cc_library( "@llvm//:core", ], ) + +cc_library( + name = "ir_builder_mixin", + srcs = [], + hdrs = ["ir_builder_mixin.h"], + deps = [ + "@llvm//:core", + ], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index fe9eab93aae95557e3ee27a64c09b78f37ac2348..8d9fa99d82b4e49b653d9f05cc9baa5e3fdcefa6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#include "absl/strings/str_cat.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace llvm_ir { diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index fe5ec1cc66d06e85ce70625ef7cf764a37b29166..b6ae4932f5707f1d15af1e09a735a7de2e48fac5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -61,7 +61,7 @@ ENTRY while3 { ; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; ; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params -; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0 +; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %buffer_table, i64 0 ; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] ; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* ; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index 4eb5d9fb4750927ca189e02f312b2d6be7fdd418..bdce4a171b8a58f617f1d56e6cf6db5354846703 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "absl/strings/str_cat.h" namespace xla { namespace llvm_ir { @@ -48,7 +49,7 @@ string ConstantBufferAllocationToGlobalName( c = '_'; } } - return tensorflow::strings::StrCat("buffer_for_", instr_name); + return absl::StrCat("buffer_for_", instr_name); } const Literal& LiteralForConstantAllocation( diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 27fbb11e2ede66a1268e7e949634b2c7d29cbc1c..cc2e862f2eb9a49099c5f90efe1b29fb77c8f106 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -40,7 +40,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& update_shape, const ElementGenerator& start_indices_generator, bool is_signed, ElementGenerator update_array_generator, const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, - tensorflow::StringPiece name, llvm::IRBuilder<>* b) { + absl::string_view name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. @@ -99,10 +99,10 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name); } -Status EmitDynamicUpdateSliceInPlace( - tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b) { +Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, + const IrArray& output_array, + absl::string_view name, + llvm::IRBuilder<>* b) { VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; // No need to use operand_arrays[0], the input array of the @@ -130,8 +130,7 @@ Status EmitDynamicUpdateSliceInPlace( // // Emits a sequential loop if launch_dimensions is null. static Status EmitFusedDynamicUpdateSliceInPlaceImpl( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); @@ -174,8 +173,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( } Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( @@ -184,8 +182,7 @@ Status EmitFusedDynamicUpdateSliceInPlace( } Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index 3502577d236a099e0b721b98217b758696966821..fb3e4eb97cae06f2a0c87dd7118b8332048df56e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -63,26 +63,24 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( // Emits IR for running the given dynamic-update-slice op in-place -- that is, // where the input and output buffers share the same slice, so we can simply // modify the input/output buffer without touching any of the other elements. -Status EmitDynamicUpdateSliceInPlace( - tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b); +Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, + const IrArray& output_array, + absl::string_view name, + llvm::IRBuilder<>* b); // Given a loop-fusion node whose root is a dynamic-update-slice op whose // array-to-be-updated and output share the same buffer slice, emits // (sequential) code for a fusion node that does the dynamic-update-slice in // place. Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b); // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with // the given launch dimensions. Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 72ede377e1a505d5e4916915e18827e1a0f3fdf9..b606c993a2d58a6d177af10de7b214de130c2279 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -98,7 +98,7 @@ Status FusedIrEmitter::HandleGetTupleElement( return Unimplemented( "GetTupleElement fusion currently only supports" " parameter operands, but found operand: %s", - operand->name().c_str()); + operand->name()); } // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( @@ -147,7 +147,7 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { } Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); + absl::Span operands(tuple->operands()); std::vector operand_elemental_ir_types; for (HloInstruction* operand : operands) { operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 30471480c4fb3ce3bf3226a28e9d2ffa79ae5f29..44d21fa750a532633f46614002d59c90fc0b5d40 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -54,7 +54,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { public: using Generator = llvm_ir::ElementGenerator; - FusedIrEmitter(tensorflow::gtl::ArraySlice parameter_arrays, + FusedIrEmitter(absl::Span parameter_arrays, ElementalIrEmitter* elemental_emitter) : parameter_arrays_(parameter_arrays), tiled_parameter_info_(nullptr), @@ -94,7 +94,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { private: // Arrays of parameters of fusion instruction - tensorflow::gtl::ArraySlice parameter_arrays_; + absl::Span parameter_arrays_; const llvm_ir::TiledParameterInfo* tiled_parameter_info_; ElementalIrEmitter* elemental_emitter_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 2b6caee6aa72f426cf85c8c56c3ef500ff8c5d3d..67f7423121177e2ca1e3384341dad2644c8f5e34 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -73,7 +73,7 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, Delinearize(&multidim_, linear, shape, b); } -IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, +IrArray::Index::Index(absl::Span multidim, llvm::Value* linear, const Shape& shape) : multidim_(multidim.begin(), multidim.end()), linear_(linear), @@ -92,7 +92,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, << " should have a layout."; } -IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, +IrArray::Index::Index(absl::Span multidim, const Shape& shape, llvm::IRBuilder<>* b) : multidim_(multidim.begin(), multidim.end()), layout_(shape.layout()), @@ -147,16 +147,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { llvm::Value* logical_linear_index = - Index(tensorflow::gtl::ArraySlice( - multidim_, common_factors[k].second, + Index(absl::Span(multidim_).subspan( + common_factors[k].second, common_factors[k + 1].second - common_factors[k].second), index_type_) - .Linearize( - tensorflow::gtl::ArraySlice( - AsInt64Slice(output_shape.dimensions()), - common_factors[k].second, - common_factors[k + 1].second - common_factors[k].second), - builder); + .Linearize(AsInt64Slice(output_shape.dimensions()) + .subspan(common_factors[k].second, + common_factors[k + 1].second - + common_factors[k].second), + builder); // Delinearizes logical_linear_index for the source array in row-major // collapsed order. The first rank-1 indices are the remainder of the // linear index by each dimension size. @@ -185,9 +184,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( } IrArray::Index IrArray::Index::SourceIndexOfSlice( - const Shape& shape, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice strides, - llvm::IRBuilder<>* builder) const { + const Shape& shape, absl::Span starts, + absl::Span strides, llvm::IRBuilder<>* builder) const { Index source_index(index_type_, multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; @@ -208,7 +206,7 @@ IrArray::Index IrArray::Index::SourceIndexOfSlice( IrArray::Index IrArray::Index::SourceIndexOfTranspose( const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, + absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { std::vector operand_multidim_index = Permute(dimension_mapping, multidim()); @@ -257,7 +255,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, + absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { int64 rank = ShapeUtil::Rank(operand_shape); std::vector source_index(rank); @@ -322,9 +320,8 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( return Index(source_index, linear, operand_shape); } -llvm::Value* IrArray::Index::Linearize( - tensorflow::gtl::ArraySlice dimensions, - llvm::IRBuilder<>* builder) const { +llvm::Value* IrArray::Index::Linearize(absl::Span dimensions, + llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. CHECK_EQ(size(), dimensions.size()); @@ -342,9 +339,9 @@ llvm::Value* IrArray::Index::Linearize( return logical_linear_index; } -llvm::Value* IrArray::EmitArrayElementAddress( - const IrArray::Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { +llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, + llvm::IRBuilder<>* b, + absl::string_view name) const { if (ShapeUtil::IsScalar(*shape_)) { // Special handling of scalars: a scalar pretends to have the same value for // every index, thus effectively implementing broadcasting of its value @@ -402,7 +399,7 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { + absl::string_view name) const { llvm::Value* element_address = EmitArrayElementAddress(index, b, name); llvm::LoadInst* load = b->CreateLoad(element_address); AnnotateLoadStoreInstructionWithMetadata(load); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 28ca793e3eeaed86664bfa6aa859a38f2c4dc6f3..f4b05f29c38529b3cce81b4c8ee6fae5c00cafcc 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -19,13 +19,14 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -69,7 +70,7 @@ class IrArray { // Constructs an index from multi-dimensional index "multidim". The linear // index is set to nullptr. - explicit Index(tensorflow::gtl::ArraySlice multidim, + explicit Index(absl::Span multidim, llvm::Type* index_ty = nullptr) : multidim_(multidim.begin(), multidim.end()) { if (size() == 0) { @@ -81,7 +82,7 @@ class IrArray { } } CHECK_NE(index_type_, nullptr); - CHECK(c_all_of(multidim, [&](llvm::Value* v) { + CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) { return index_type_ == v->getType(); })); } @@ -98,14 +99,14 @@ class IrArray { // that it indexes into. // // Precondition: "shape" has a layout. - Index(tensorflow::gtl::ArraySlice multidim, - const Shape& shape, llvm::IRBuilder<>* b); + Index(absl::Span multidim, const Shape& shape, + llvm::IRBuilder<>* b); // Constructs an index from both a multi-dimensional index and a linear // index. "shape" has the same meaning as that in the constructor that takes // only a linear index. - Index(tensorflow::gtl::ArraySlice multidim, - llvm::Value* linear, const Shape& shape); + Index(absl::Span multidim, llvm::Value* linear, + const Shape& shape); const std::vector& multidim() const { return multidim_; } llvm::Value* linear() const { return linear_; } @@ -144,17 +145,15 @@ class IrArray { // by starting indices `starts` and stride values `strides`. // // Precondition: "this" is an index into a slice whose shape is `shape`. - Index SourceIndexOfSlice(const Shape& shape, - tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice strides, + Index SourceIndexOfSlice(const Shape& shape, absl::Span starts, + absl::Span strides, llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a transpose from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. - Index SourceIndexOfTranspose( - const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, - llvm::IRBuilder<>* builder) const; + Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape, + absl::Span dimension_mapping, + llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a bitcast from `operand_shape` // to `shape`, returns the source index. @@ -163,14 +162,13 @@ class IrArray { // Given that "this" is the target index of a broadcast from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. - Index SourceIndexOfBroadcast( - const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, - llvm::IRBuilder<>* builder) const; + Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape, + absl::Span dimension_mapping, + llvm::IRBuilder<>* builder) const; // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and // returns the index into the sole dimension 0 of the new shape. - llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, + llvm::Value* Linearize(absl::Span dimensions, llvm::IRBuilder<>* builder) const; llvm::Type* GetType() const { return index_type_; } @@ -240,7 +238,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Attach metadata this IrArray instance knows about to "instruction". void AnnotateLoadStoreInstructionWithMetadata( @@ -254,7 +252,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Emit IR to write the given value to the array element at the given index. void EmitWriteArrayElement(const Index& index, llvm::Value* value, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h new file mode 100644 index 0000000000000000000000000000000000000000..abc06fb7b4245294df2dc20d25a22ac4fdaeb4cf --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -0,0 +1,400 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ + +#include "llvm/IR/IRBuilder.h" + +namespace xla { + +// Mixin class that injects more ergonomic versions of llvm::IRBuilder methods +// into a class. Intended to be used as a CRTP base class, like: +// +// class MyIrEmitter : public IrBuilderMixin { +// llvm::IRBuilder<>* builder() { return builder_; } +// +// void EmitFoo(HloInstruction* foo) { +// Add(Mul(...), FPToUI(...)); +// } +// }; + +template +class IrBuilderMixin { + protected: + template + llvm::Value* Add(Args&&... args) { + return mixin_builder()->CreateAdd(std::forward(args)...); + } + + template + llvm::LoadInst* AlignedLoad(Args&&... args) { + return mixin_builder()->CreateAlignedLoad(std::forward(args)...); + } + + template + llvm::StoreInst* AlignedStore(Args&&... args) { + return mixin_builder()->CreateAlignedStore(std::forward(args)...); + } + + template + llvm::AllocaInst* Alloca(Args&&... args) { + return mixin_builder()->CreateAlloca(std::forward(args)...); + } + + template + llvm::Value* And(Args&&... args) { + return mixin_builder()->CreateAnd(std::forward(args)...); + } + + template + llvm::Value* AtomicCmpXchg(Args&&... args) { + return mixin_builder()->CreateAtomicCmpXchg(std::forward(args)...); + } + + template + llvm::Value* AtomicRMW(Args&&... args) { + return mixin_builder()->CreateAtomicRMW(std::forward(args)...); + } + + template + llvm::Value* BitCast(Args&&... args) { + return mixin_builder()->CreateBitCast(std::forward(args)...); + } + + template + llvm::Value* Br(Args&&... args) { + return mixin_builder()->CreateBr(std::forward(args)...); + } + + llvm::CallInst* Call(llvm::Value* callee, + llvm::ArrayRef args = llvm::None, + const llvm::Twine& name = "", + llvm::MDNode* fp_math_tag = nullptr) { + return mixin_builder()->CreateCall(callee, args, name, fp_math_tag); + } + + template + llvm::BranchInst* CondBr(Args&&... args) { + return mixin_builder()->CreateCondBr(std::forward(args)...); + } + + template + llvm::Value* ConstInBoundsGEP1_32(Args&&... args) { + return mixin_builder()->CreateConstInBoundsGEP1_32( + std::forward(args)...); + } + + template + llvm::Value* FAdd(Args&&... args) { + return mixin_builder()->CreateFAdd(std::forward(args)...); + } + + template + llvm::Value* FMul(Args&&... args) { + return mixin_builder()->CreateFMul(std::forward(args)...); + } + + llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateGEP(ptr, idx_list, name); + } + + template + llvm::Value* ICmpEQ(Args&&... args) { + return mixin_builder()->CreateICmpEQ(std::forward(args)...); + } + + template + llvm::Value* ICmpNE(Args&&... args) { + return mixin_builder()->CreateICmpNE(std::forward(args)...); + } + + template + llvm::Value* ICmpULE(Args&&... args) { + return mixin_builder()->CreateICmpULE(std::forward(args)...); + } + + template + llvm::Value* ICmpULT(Args&&... args) { + return mixin_builder()->CreateICmpULT(std::forward(args)...); + } + + llvm::Value* InBoundsGEP(llvm::Value* ptr, + llvm::ArrayRef idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name); + } + + llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateExtractValue(agg, idxs, name); + } + + llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val, + llvm::ArrayRef idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInsertValue(agg, val, idxs, name); + } + + template + llvm::Value* IntToPtr(Args&&... args) { + return mixin_builder()->CreateIntToPtr(std::forward(args)...); + } + + template + llvm::LoadInst* Load(Args&&... args) { + return mixin_builder()->CreateLoad(std::forward(args)...); + } + + template + llvm::CallInst* MemCpy(Args&&... args) { + return mixin_builder()->CreateMemCpy(std::forward(args)...); + } + + template + llvm::Value* Mul(Args&&... args) { + return mixin_builder()->CreateMul(std::forward(args)...); + } + + template + llvm::Value* NSWAdd(Args&&... args) { + return mixin_builder()->CreateNSWAdd(std::forward(args)...); + } + + template + llvm::Value* NSWMul(Args&&... args) { + return mixin_builder()->CreateNSWMul(std::forward(args)...); + } + + template + llvm::Value* NSWSub(Args&&... args) { + return mixin_builder()->CreateNSWSub(std::forward(args)...); + } + + template + llvm::Value* Or(Args&&... args) { + return mixin_builder()->CreateOr(std::forward(args)...); + } + + template + llvm::Value* PointerCast(Args&&... args) { + return mixin_builder()->CreatePointerCast(std::forward(args)...); + } + + template + llvm::Value* PtrToInt(Args&&... args) { + return mixin_builder()->CreatePtrToInt(std::forward(args)...); + } + + template + llvm::Value* SDiv(Args&&... args) { + return mixin_builder()->CreateSDiv(std::forward(args)...); + } + + template + llvm::Value* Select(Args&&... args) { + return mixin_builder()->CreateSelect(std::forward(args)...); + } + + template + llvm::Value* SRem(Args&&... args) { + return mixin_builder()->CreateSRem(std::forward(args)...); + } + + template + llvm::StoreInst* Store(Args&&... args) { + return mixin_builder()->CreateStore(std::forward(args)...); + } + + template + llvm::Value* UDiv(Args&&... args) { + return mixin_builder()->CreateUDiv(std::forward(args)...); + } + + template + llvm::Value* URem(Args&&... args) { + return mixin_builder()->CreateURem(std::forward(args)...); + } + + template + llvm::Value* VectorSplat(Args&&... args) { + return mixin_builder()->CreateVectorSplat(std::forward(args)...); + } + + template + llvm::Value* ZExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateZExtOrTrunc(std::forward(args)...); + } + + template + llvm::Value* AShr(Args&&... args) { + return mixin_builder()->CreateAShr(std::forward(args)...); + } + + template + llvm::Value* FCmpOEQ(Args&&... args) { + return mixin_builder()->CreateFCmpOEQ(std::forward(args)...); + } + + template + llvm::Value* FCmpOLT(Args&&... args) { + return mixin_builder()->CreateFCmpOLT(std::forward(args)...); + } + + template + llvm::Value* FCmpONE(Args&&... args) { + return mixin_builder()->CreateFCmpONE(std::forward(args)...); + } + + template + llvm::Value* FCmpUNE(Args&&... args) { + return mixin_builder()->CreateFCmpUNE(std::forward(args)...); + } + + template + llvm::Value* FDiv(Args&&... args) { + return mixin_builder()->CreateFDiv(std::forward(args)...); + } + + template + llvm::Value* FNeg(Args&&... args) { + return mixin_builder()->CreateFNeg(std::forward(args)...); + } + + template + llvm::Value* FPCast(Args&&... args) { + return mixin_builder()->CreateFPCast(std::forward(args)...); + } + + template + llvm::Value* FPToSI(Args&&... args) { + return mixin_builder()->CreateFPToSI(std::forward(args)...); + } + + template + llvm::Value* FPToUI(Args&&... args) { + return mixin_builder()->CreateFPToUI(std::forward(args)...); + } + + template + llvm::Value* FPTrunc(Args&&... args) { + return mixin_builder()->CreateFPTrunc(std::forward(args)...); + } + + template + llvm::Value* FRem(Args&&... args) { + return mixin_builder()->CreateFRem(std::forward(args)...); + } + + template + llvm::Value* FSub(Args&&... args) { + return mixin_builder()->CreateFSub(std::forward(args)...); + } + + template + llvm::Value* ICmpSGE(Args&&... args) { + return mixin_builder()->CreateICmpSGE(std::forward(args)...); + } + + template + llvm::Value* ICmpSLT(Args&&... args) { + return mixin_builder()->CreateICmpSLT(std::forward(args)...); + } + + template + llvm::Value* IntCast(Args&&... args) { + return mixin_builder()->CreateIntCast(std::forward(args)...); + } + + template + llvm::Value* LShr(Args&&... args) { + return mixin_builder()->CreateLShr(std::forward(args)...); + } + + template + llvm::Value* MemSet(Args&&... args) { + return mixin_builder()->CreateMemSet(std::forward(args)...); + } + + template + llvm::Value* Neg(Args&&... args) { + return mixin_builder()->CreateNeg(std::forward(args)...); + } + + template + llvm::Value* Not(Args&&... args) { + return mixin_builder()->CreateNot(std::forward(args)...); + } + + template + llvm::PHINode* PHI(Args&&... args) { + return mixin_builder()->CreatePHI(std::forward(args)...); + } + + template + llvm::Value* RetVoid(Args&&... args) { + return mixin_builder()->CreateRetVoid(std::forward(args)...); + } + + template + llvm::Value* SExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateSExtOrTrunc(std::forward(args)...); + } + + template + llvm::Value* Shl(Args&&... args) { + return mixin_builder()->CreateShl(std::forward(args)...); + } + + template + llvm::Value* SIToFP(Args&&... args) { + return mixin_builder()->CreateSIToFP(std::forward(args)...); + } + + template + llvm::Value* Sub(Args&&... args) { + return mixin_builder()->CreateSub(std::forward(args)...); + } + + template + llvm::Value* Trunc(Args&&... args) { + return mixin_builder()->CreateTrunc(std::forward(args)...); + } + + template + llvm::Value* UIToFP(Args&&... args) { + return mixin_builder()->CreateUIToFP(std::forward(args)...); + } + + template + llvm::Value* Unreachable(Args&&... args) { + return mixin_builder()->CreateUnreachable(std::forward(args)...); + } + + template + llvm::Value* Xor(Args&&... args) { + return mixin_builder()->CreateXor(std::forward(args)...); + } + + private: + llvm::IRBuilder<>* mixin_builder() { + return static_cast(this)->builder(); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index b79567369aa532c4963e3941f6cb9844cd1476dd..bd0139f85b6a5c5dc23dad962263038451921e65 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -19,7 +19,7 @@ limitations under the License. namespace xla { Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { return If(b_->CreateICmpSLT(start, end), [&]() -> Status { @@ -30,7 +30,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator) { @@ -56,7 +56,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::If( - tensorflow::StringPiece name, llvm::Value* condition, + absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_); @@ -70,7 +70,7 @@ Status KernelSupportLibrary::If( void KernelSupportLibrary::EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, + absl::string_view kernel_name, KernelSupportLibrary::ArgumentVector arguments, const std::function& kernel_body_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index b00f903d56a83c5b76188007702470c44c55c213..43fec311f150d6054f6ad24f99db332f90ff94a3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ #include +#include "absl/strings/string_view.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { // A thin wrapper around llvm_loop.h to make code generating structured control @@ -49,13 +49,13 @@ class KernelSupportLibrary { // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator); void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { @@ -67,7 +67,7 @@ class KernelSupportLibrary { })); } - Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + Status For(absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { @@ -77,7 +77,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), @@ -99,13 +99,13 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator); - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& @@ -129,7 +129,7 @@ class KernelSupportLibrary { peel_first_iteration, for_body_generator); } - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, bool peel_first_iteration, const std::function& @@ -140,7 +140,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { return For(name, start, end, step, @@ -151,7 +151,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { ForReturnVoid(name, start, end, step, @@ -162,8 +162,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), /*peel_first_iteration=*/false, @@ -173,8 +172,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { ForReturnVoid(name, start, end, llvm::ConstantInt::get(start->getType(), step), @@ -182,7 +180,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { return For(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -190,7 +188,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -203,7 +201,7 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(tensorflow::StringPiece name, llvm::Value* condition, + Status If(absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() -> Status { return Status::OK(); }); @@ -222,7 +220,7 @@ class KernelSupportLibrary { IfReturnVoid("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition, + void IfReturnVoid(absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() { }) { @@ -237,7 +235,7 @@ class KernelSupportLibrary { })); } - using ArgumentVector = tensorflow::gtl::ArraySlice; + using ArgumentVector = absl::Span; // Generates the following control flow structure: // @@ -259,13 +257,13 @@ class KernelSupportLibrary { // Currently we only support at most one nullptr value in `arguments`. static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, ArgumentVector arguments, + absl::string_view kernel_name, ArgumentVector arguments, const std::function& kernel_body_generator); // Thin wrappers around the more general EmitAndCallOutlinedKernel above. static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, const std::function& kernel_body_generator) { @@ -278,7 +276,7 @@ class KernelSupportLibrary { static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, llvm::Value* arg3, const std::function& kernel_body_generator) { @@ -296,4 +294,4 @@ class KernelSupportLibrary { }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index 35b394127288d816952b48c84b193257bab0bcda..e5fbdbd51b8a9aa14decadedd1eeb3bdbf831738 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -28,7 +28,7 @@ namespace { // Returns the indices of the first elements of all consecutive subarrays of the // given array. For example: // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { +std::vector ConsecutiveSegments(absl::Span xs) { std::vector is = {0}; for (size_t i = 1; i < xs.size(); ++i) { if (1 != xs[i] - xs[i - 1]) { @@ -40,8 +40,7 @@ std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { // Merges the sequences of dimensions of the given shape which start at the // given indices `segs`. -Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, - const Shape& shape) { +Shape MergeDimensions(absl::Span segs, const Shape& shape) { std::vector dimensions; for (size_t i = 1; i <= segs.size(); ++i) { dimensions.push_back(std::accumulate( @@ -55,10 +54,10 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, } } // namespace -tensorflow::gtl::optional > FindTranspose021( - const Shape& a, const Shape& b) { +absl::optional > FindTranspose021(const Shape& a, + const Shape& b) { if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } std::vector perm(a.dimensions().size()); @@ -88,7 +87,7 @@ tensorflow::gtl::optional > FindTranspose021( return dims_021; } - return tensorflow::gtl::nullopt; + return absl::nullopt; } IrArray::Index GetUnreducedOutputIndex( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index ccb9b8ba3e6b0079664f2da92ce67224e176fa1d..5ea05b3188a1c0881e4c0c41625d530aff1b1205 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -36,8 +36,8 @@ namespace llvm_ir { // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the // reduced shape of `b` or the 0-2-1 shape. -tensorflow::gtl::optional > FindTranspose021(const Shape& a, - const Shape& b); +absl::optional > FindTranspose021(const Shape& a, + const Shape& b); // Return the unreduced output index corresponding to the given reduced output // index. @@ -50,7 +50,7 @@ IrArray::Index GetUnreducedOutputIndex( // for 021 transpose. class TiledParameterInfo { public: - TiledParameterInfo(tensorflow::gtl::ArraySlice param_buffers, + TiledParameterInfo(absl::Span param_buffers, llvm::Value* y, llvm::Value* x) : param_buffers_(param_buffers), y_(y), x_(x) {} @@ -67,7 +67,7 @@ class TiledParameterInfo { private: // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr // if the parameter is not tiled. - tensorflow::gtl::ArraySlice param_buffers_; + absl::Span param_buffers_; // The y coordinate within a tile. llvm::Value* y_; // The x coordinate within a tile. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index ba7f94834c7fd04d97cec012537244323308b8ce..219a9f221fbd116cdfbaf17985e21d82aefd079d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -25,19 +26,17 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, +ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) - : prefix_(std::string(prefix)), - suffix_(std::string(suffix)), + : prefix_(prefix), + suffix_(suffix), start_index_(start_index), end_index_(end_index), step_(step), @@ -46,9 +45,9 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, - UnrollMode unroll_mode, bool prevent_vectorization) { + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode, + bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, end_index, step, unroll_mode, prevent_vectorization)); @@ -168,16 +167,16 @@ std::vector ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) { return result; } -string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { +string ForLoop::GetQualifiedName(absl::string_view name) { return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); } -llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, +llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b) { return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b); } -std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, +std::unique_ptr ForLoopNest::AddLoop(absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode, @@ -186,12 +185,9 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, unroll_mode, prevent_vectorization); } -std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - llvm::Value* stride, - UnrollMode unroll_mode, - bool prevent_vectorization) { +std::unique_ptr ForLoopNest::AddLoop( + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); @@ -216,7 +212,7 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -227,7 +223,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -238,22 +234,22 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix) { + absl::string_view suffix) { std::vector dimensions(ShapeUtil::Rank(shape)); std::iota(dimensions.begin(), dimensions.end(), 0); return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); } IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( - const Shape& shape, tensorflow::gtl::ArraySlice dimensions, - tensorflow::StringPiece suffix) { + const Shape& shape, absl::Span dimensions, + absl::string_view suffix) { llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr loop = AddLoop( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ - llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension))); + llvm_ir::IrName(suffix, absl::StrCat(dimension))); index[dimension] = loop->GetIndVarValue(); } return index; @@ -261,7 +257,7 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix) { + absl::string_view name_suffix) { // Prepares the dimension list we will use to emit the loop nest. Outermost // loops are added first. Add loops in major-to-minor order, and skip the // 'dimension_to_skip' dimension. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index a4fed5c8dc55d38d25031252e3960404a5bf84e6..ac3bba3c9fd6a9eb4e7822474963fcc5a394baf7 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -19,15 +19,15 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -78,7 +78,7 @@ class ForLoop { // `unroll_mode` specifies the desired LLVM unrolling behavior for generated // loop. static std::unique_ptr EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -133,19 +133,18 @@ class ForLoop { // Allow ForLoopNest to call this private constructor. friend class ForLoopNest; - ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, + ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* b); - llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name, - llvm::IRBuilder<>* b); + llvm::BasicBlock* CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b); // Creates a name for an LLVM construct, appending prefix_ and suffix_, if // they are set. - string GetQualifiedName(tensorflow::StringPiece name); + string GetQualifiedName(absl::string_view name); // Return a list of metadata nodes that should be associated with the // llvm::Loop for this `ForLoop`. @@ -182,9 +181,9 @@ class ForLoopNest { SetIndexType(index_ty); } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b, + ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) - : name_(std::string(name)), + : name_(name), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), @@ -197,14 +196,14 @@ class ForLoopNest { // been added then emit loop inside the body of the last added loop. // unroll_mode is used to emit metadata that controls LLVM unrolling. std::unique_ptr AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -213,13 +212,13 @@ class ForLoopNest { // end index are constant. std::unique_ptr AddLoop( int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( - int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + int64 start_index, int64 end_index, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -234,8 +233,7 @@ class ForLoopNest { // within the shape. One possible order for that sequence would be: // // (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) - IrArray::Index AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix); + IrArray::Index AddLoopsForShape(const Shape& shape, absl::string_view suffix); // Add a loop for each dimension in "dimensions". "suffix" is the // name suffix of the indvar and basic blocks in this new loop nest. @@ -244,8 +242,8 @@ class ForLoopNest { // size equals the rank of shape and there is a null for each // dimension that is not in "dimensions". IrArray::Index AddLoopsForShapeOnDimensions( - const Shape& shape, tensorflow::gtl::ArraySlice dimensions, - tensorflow::StringPiece suffix); + const Shape& shape, absl::Span dimensions, + absl::string_view suffix); // Emits a series of nested loops for iterating over an operand array. Loops // are constructed in major to minor dimension layout order. No loop is @@ -256,7 +254,7 @@ class ForLoopNest { // basic blocks) constructed by this method. IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix); + absl::string_view name_suffix); // Convenience methods which return particular basic blocks of the outermost // or innermost loops. These methods return nullptr if no loops have been diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index e6126881af8b8123e08a4eaa934b52a7fd378ce6..1a53c026be340ca3bec3a49b11666d6124728130 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/MDBuilder.h" @@ -34,8 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -61,7 +61,7 @@ string AsString(const std::string& str) { return string(str.data(), str.length()); } -llvm::StringRef AsStringRef(tensorflow::StringPiece str) { +llvm::StringRef AsStringRef(absl::string_view str) { return llvm::StringRef(str.data(), str.size()); } @@ -83,11 +83,10 @@ string DumpModuleToString(const llvm::Module& module) { return AsString(buffer_string); } -llvm::Value* EmitCallToIntrinsic( - llvm::Intrinsic::ID intrinsic_id, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* b) { +llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, + absl::Span operands, + absl::Span overloaded_types, + llvm::IRBuilder<>* b) { llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); @@ -262,15 +261,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment) { return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment); } -llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment) { +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment) { llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP(); llvm::Function* function = b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), @@ -285,7 +286,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( } llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b) { return llvm::BasicBlock::Create( /*Context=*/b->getContext(), @@ -294,27 +295,25 @@ llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, /*InsertBefore*/ insert_before); } -LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else) { llvm_ir::LlvmIfData if_data; if_data.if_block = b->GetInsertBlock(); if_data.true_block = - CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b); + CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b); if_data.false_block = - emit_else ? CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-false"), b) + emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b) : nullptr; // Add a terminator to the if block, if necessary. if (if_data.if_block->getTerminator() == nullptr) { b->SetInsertPoint(if_data.if_block); - if_data.after_block = CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-after"), b); + if_data.after_block = + CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b); b->CreateBr(if_data.after_block); } else { if_data.after_block = if_data.if_block->splitBasicBlock( - b->GetInsertPoint(), - AsStringRef(tensorflow::strings::StrCat(name, "-after"))); + b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after"))); } // Our basic block should now end with an unconditional branch. Remove it; @@ -413,14 +412,14 @@ string IrName(string a) { return a; } -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) { +string IrName(absl::string_view a, absl::string_view b) { if (!a.empty() && !b.empty()) { - return IrName(tensorflow::strings::StrCat(a, ".", b)); + return IrName(absl::StrCat(a, ".", b)); } - return IrName(tensorflow::strings::StrCat(a, b)); + return IrName(absl::StrCat(a, b)); } -string IrName(const HloInstruction* a, tensorflow::StringPiece b) { +string IrName(const HloInstruction* a, absl::string_view b) { return IrName(a->name(), b); } @@ -556,7 +555,7 @@ std::map MergeMetadata( return result; } -static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { +static string GetProcessUniqueIrFileName(absl::string_view prefix) { static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); @@ -584,18 +583,16 @@ Status DumpIRToDirectory(const string& directory_name, // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously // dumped from the same process in such cases. string unique_and_safe_file_name = GetProcessUniqueIrFileName( - tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", - optimized ? "with" : "no", "-opt")); + absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", + optimized ? "with" : "no", "-opt")); string ir_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, ".ll")); // For some models the embedded constants can be huge, so also dump the module // with the constants stripped to get IR that is easier to manipulate. string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll")); TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( directory_name, ir_file_name, DumpModuleToString(llvm_module))); @@ -607,8 +604,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module) { + absl::string_view name, llvm::Module* module) { llvm::Function* function = llvm::Function::Create(function_type, linkage, AsStringRef(name), module); function->setCallingConv(llvm::CallingConv::C); @@ -638,7 +634,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. - if (!tensorflow::str_util::StartsWith(it.first, "xla_")) { + if (!absl::StartsWith(it.first, "xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 09583985342033d486d50910b6f5ca732a9a3756..f59baff263fe7184c6b0821c9dbd9eee205586a6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" @@ -32,8 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace llvm { @@ -47,11 +47,11 @@ namespace llvm_ir { // Convert a std::string (used by LLVM's interfaces) to string. string AsString(const std::string& str); -// Convert a tensorflow::StringPiece to a llvm::StringRef. Note: both -// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a +// Convert a absl::string_view to a llvm::StringRef. Note: both +// absl::string_view and llvm::StringRef are non-owning pointers into a // string in memory. This method is used to feed strings to LLVM // & Clang APIs that expect llvm::StringRef. -llvm::StringRef AsStringRef(tensorflow::StringPiece str); +llvm::StringRef AsStringRef(absl::string_view str); template llvm::ArrayRef AsArrayRef(const std::vector& vec) { @@ -59,7 +59,7 @@ llvm::ArrayRef AsArrayRef(const std::vector& vec) { } template -llvm::ArrayRef AsArrayRef(const tensorflow::gtl::ArraySlice& slice) { +llvm::ArrayRef AsArrayRef(const absl::Span& slice) { return llvm::ArrayRef(slice.data(), slice.size()); } @@ -88,8 +88,8 @@ string DumpModuleToString(const llvm::Module& module); // - removing all '%'s. // string IrName(string a); -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b); -string IrName(const HloInstruction* a, tensorflow::StringPiece b = ""); +string IrName(absl::string_view a, absl::string_view b); +string IrName(const HloInstruction* a, absl::string_view b = ""); // Removes special characters from a function name. // @@ -101,11 +101,10 @@ string SanitizeFunctionName(string function_name); // intrinsics (for example, "minnum") must include a type in overloaded_types // for each overloaded type. Typically, overloaded intrinsics have only a single // overloaded type. -llvm::Value* EmitCallToIntrinsic( - llvm::Intrinsic::ID intrinsic_id, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* b); +llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, + absl::Span operands, + absl::Span overloaded_types, + llvm::IRBuilder<>* b); // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise @@ -164,21 +163,23 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, // This can be useful to avoid e.g. executing an alloca every time // through a loop. llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment = 0); // As EmitAllocaAtFunctionEntry, but allocates element_count entries // instead of a single element. -llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment = 0); +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment = 0); // Creates a basic block with the same context and function as for the // builder. Inserts at the end of the function if insert_before is // null. llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b); // Struct with data on a conditional branch in a diamond shape created @@ -210,7 +211,7 @@ struct LlvmIfData { // Currently the insertion point of the builder must be a well-formed // block with a terminator. If you need to use this for a // non-terminated block, just make the function able to do that too. -LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else = true); // Emits a compare operation between "lhs" and "rhs" with the given predicate, @@ -285,8 +286,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module); + absl::string_view name, llvm::Module* module); // Extracts the xla_backend_extra_options from `config` and passes those that // don't start with xla_ to LLVM. diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 36f5fa195224c20e30a14f72b32eb42a681bb5e9..0dc120e0b0df47f261435f490a8459b49d989b53 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -69,7 +69,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( } LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, + absl::Span target_arrays, llvm::IRBuilder<>* b) : body_emitter_(MakeBodyEmitterForMultiOutputFusion( target_element_generator, @@ -86,7 +86,7 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { + absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. @@ -105,7 +105,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } @@ -122,7 +122,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, +Status LoopEmitter::EmitLoop(absl::string_view loop_name, llvm::Type* index_type) { if (index_type == nullptr) { index_type = b_->getInt64Ty(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index c4f5c82086ccfa233e0be118b1de10cce55a51b1..a537c00066b0a68404b142e91283510092b46e2d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -53,8 +53,7 @@ class LoopEmitter { // This is used for multi-output fusion. target_element_generator must // produce an LLVM struct with N elements. LoopEmitter(const ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, - llvm::IRBuilder<>* b); + absl::Span target_arrays, llvm::IRBuilder<>* b); LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; @@ -69,10 +68,10 @@ class LoopEmitter { } virtual std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type); + absl::string_view loop_name, llvm::Type* index_type); // Emits a complete loop nest for every element in the given shape. - Status EmitLoop(tensorflow::StringPiece loop_name = "", + Status EmitLoop(absl::string_view loop_name = "", llvm::Type* index_type = nullptr); protected: diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index e546f5cc4ae305b40c1bdbcae090daadee11241b..944c79580c133906cd431722fd6b29e6aee5f918 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -16,6 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -29,8 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -42,7 +43,7 @@ namespace { void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, const IrArray::Index& compare_keys_index, const IrArray& keys_array, - const tensorflow::gtl::optional& values_array, + const absl::optional& values_array, llvm::IRBuilder<>* b) { // if (is_smaller_index && // compare_keys[dimension_to_sort] < dimension_to_sort_bound) @@ -59,15 +60,39 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, SetToFirstInsertPoint(if_data.true_block, b); auto key1 = keys_array.EmitReadArrayElement(keys_index, b); auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b); + auto compare_key1 = key1; + auto compare_key2 = key2; auto key_type = keys_array.GetShape().element_type(); + bool is_signed_comparison = true; + if (primitive_util::IsFloatingPointType(key_type)) { + // We would like a total order of floating point numbers so that the sort + // has a predictable behavior in the presence of NaNs. Rather than using + // floating point comparison, we use the following trick: + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 0x7FFFFFFF - x : x; + // then y is ordered as an int32 such that finite values have the obvious + // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning + // and end of the ordering. + auto k = b->getInt(llvm::APInt::getSignedMaxValue( + key1->getType()->getPrimitiveSizeInBits())); + auto comparison_type = k->getType(); + auto zero = llvm::ConstantInt::get(comparison_type, 0); + auto maybe_flip = [&](llvm::Value* v) { + return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), + b->CreateSub(k, v), v); + }; + compare_key1 = b->CreateBitCast(key1, comparison_type); + compare_key2 = b->CreateBitCast(key2, comparison_type); + compare_key1 = maybe_flip(compare_key1); + compare_key2 = maybe_flip(compare_key2); + } else if (!primitive_util::IsSignedIntegralType(key_type)) { + is_signed_comparison = false; + } auto comparison = - primitive_util::IsFloatingPointType(key_type) - // TODO(b/26783907): Figure out how to handle NaNs. - ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1) - : b->CreateICmp(primitive_util::IsSignedIntegralType(key_type) - ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - key2, key1); + b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + compare_key2, compare_key1); // If key2 < key1 auto if_smaller_data = EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); @@ -87,8 +112,8 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const tensorflow::gtl::optional& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + const absl::optional& values_array, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions) { const Shape& keys_shape = keys_array.GetShape(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 8458744c6bc0e50a1c1cc8d3e66e29c7d4f74d73..527ed10374ce9482045a8459e38fd041e0e83001 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -31,8 +31,8 @@ namespace llvm_ir { // implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, // the inner compare loop will not be parallelized. Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const tensorflow::gtl::optional& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + const absl::optional& values_array, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions); } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 11ed6ee59f1bf8e7004b8bef7319b37ef41a304c..a60643bc754f896d096b3ca4e1216e77d7e384c6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -64,8 +64,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, } } -void EmitTuple(const IrArray& tuple, - tensorflow::gtl::ArraySlice operands, +void EmitTuple(const IrArray& tuple, absl::Span operands, llvm::IRBuilder<>* b, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { auto* store = b->CreateStore( @@ -76,6 +75,16 @@ void EmitTuple(const IrArray& tuple, } } +void EmitTuple(const IrArray& tuple, absl::Span buffers, + llvm::IRBuilder<>* b, llvm::Module* module) { + std::vector buffer_ptrs; + buffer_ptrs.reserve(buffers.size()); + absl::c_transform( + buffers, std::back_inserter(buffer_ptrs), + [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); }); + llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module); +} + llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, llvm::IRBuilder<>* b, llvm::Module* module) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index cf6bf5d0b14ba71cbed67f9a1dc728c0eef5e393..94340b91d8eeea1ba4681c2e49c0894eab2f6cc0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" // Utilities for emitting LLVM IR related to HLO tuples. @@ -65,8 +65,12 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. -void EmitTuple(const IrArray& tuple, - tensorflow::gtl::ArraySlice operands, +void EmitTuple(const IrArray& tuple, absl::Span operands, + llvm::IRBuilder<>* b, llvm::Module* module); + +// Similar to EmitTuple above, except that the output buffers are provided in +// the form of IrArray. +void EmitTuple(const IrArray& tuple, absl::Span buffers, llvm::IRBuilder<>* b, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 5e02096ee501b23a7976a50f13bb7e7f3c5e2d34..0d0fb7946ae6815905491ca55652d7d0ab278a3c 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,10 +19,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -37,7 +39,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -73,7 +74,7 @@ namespace { // If the parameter number is invalid for this computation, nullopt is // returned. When the return value has_value(), nullptr will never be // the held value. -tensorflow::gtl::optional ParameterMetadata( +absl::optional ParameterMetadata( const XlaComputation& computation, int parameter_number) { for (const HloComputationProto& comp : computation.proto().computations()) { if (comp.id() == computation.proto().entry_computation_id()) { @@ -81,14 +82,14 @@ tensorflow::gtl::optional ParameterMetadata( if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && instr.parameter_number() == parameter_number) { if (!instr.has_metadata()) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } return &instr.metadata(); } } } } - return tensorflow::gtl::nullopt; + return absl::nullopt; } ExecutionOptions CreateExecutionOptions( @@ -140,7 +141,7 @@ ExecutionOptions CreateExecutionOptions( StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& build_options) { const HloModuleProto& proto = computation.proto(); TF_RET_CHECK(proto.has_program_shape()); @@ -149,7 +150,7 @@ StatusOr> LocalService::CompileExecutable( // Validate incoming layouts. if (argument_layouts.size() != program_shape.parameters_size()) { return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", + "Invalid number of arguments for computation: expected %d, got %u.", program_shape.parameters_size(), argument_layouts.size()); } @@ -158,7 +159,7 @@ StatusOr> LocalService::CompileExecutable( TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { - tensorflow::gtl::optional metadata = + absl::optional metadata = ParameterMetadata(computation, /*parameter_number=*/i); auto metadata_string = [&metadata]() -> string { if (!metadata.has_value()) { @@ -167,16 +168,15 @@ StatusOr> LocalService::CompileExecutable( CHECK(metadata.value() != nullptr); const OpMetadata& m = *metadata.value(); if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); + return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line()); } return ""; }; return InvalidArgument( "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); + metadata_string(), + ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(argument_shape)); } } if (build_options.result_layout() != nullptr) { @@ -214,7 +214,7 @@ StatusOr LocalService::GlobalDataToShapedBuffer( TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); if (replica_number >= buffers.size()) { return InvalidArgument( - "replica_number %d out of range; must be less than num_replicas = %zu.", + "replica_number %d out of range; must be less than num_replicas = %u.", replica_number, buffers.size()); } return buffers[replica_number]; diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 8f707ea9046a00a15cac469672a7a992f20bf483..3b4f0b50832d6d2b64528ffb63eb5c7375396aec 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -48,7 +48,7 @@ class LocalService : public Service { // compiler is responsible for freeing any memory it allocates this way. StatusOr> CompileExecutable( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& build_options); // Returns the device ordinal that corresponds to the given replica number. diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index c742d35a7bcafa66692195a513992c9cfbb39335..e1f56727bd209797c60f7b3f10c3e232992d01e0 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -34,11 +34,10 @@ LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { string color_string; if (has_color()) { - color_string = tensorflow::strings::StrCat(" @", color().value()); + color_string = absl::StrCat(" @", color().value()); } - return tensorflow::strings::StrCat(instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "](#", id(), color_string, ")"); + return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","), + "](#", id(), color_string, ")"); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index f9ba5a554740c9d4cc2643fe59d18ba76c30d03b..ceacab4ed7319527312a5a6ad715103b5bbaf40f 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -18,13 +18,13 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/int_type.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index d631fb5ee42df6525681a5cd1fe1a8241824121d..ec52a24d782a44fda961feab3230886072e755c7 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -53,7 +54,7 @@ Status LogicalBufferAnalysis::Analyze() { // so reserve 10% more than the number of instructions to avoid frequent // resizes. logical_buffers_.clear(); - logical_buffers_.reserve((module_->NumUniqueInstructionIds() * 11) / 10); + logical_buffers_.reserve((module_->instruction_count() * 11) / 10); // We filter out fusion computations, and get to them through fusion // instructions. This is because it's possible to have orphaned (unreachable) @@ -89,7 +90,7 @@ void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index) { CHECK_EQ(logical_buffers_.size(), next_buffer_id_); logical_buffers_.emplace_back( - MakeUnique(instruction, index, next_buffer_id_)); + absl::make_unique(instruction, index, next_buffer_id_)); output_buffers_[std::make_pair(instruction, index)] = logical_buffers_.back().get(); diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..8269842426e3ee15ea974098a43fe7752c7614df --- /dev/null +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" +#include "absl/types/variant.h" +namespace xla { + +se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() { + if (HasOwnership()) { + return absl::get(mem_).AsDeviceMemoryBase(); + } else { + return absl::get(mem_); + } +} + +bool MaybeOwningDeviceMemory::HasOwnership() const { + return absl::holds_alternative(mem_); +} + +absl::optional MaybeOwningDeviceMemory::Release() { + if (!HasOwnership()) { + return {}; + } + OwningDeviceMemory result = std::move(absl::get(mem_)); + mem_ = result.AsDeviceMemoryBase(); + return absl::make_optional(std::move(result)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..82e7f1183c086437e10daea85ea99235b06cbb35 --- /dev/null +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ + +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" + +namespace xla { + +// MaybeOwningDeviceMemory represents either an owned or unowned device memory. +// Like std::variant. When the object goes +// output of scope, it will free the underlying memory if it owns it. +class MaybeOwningDeviceMemory { + public: + MaybeOwningDeviceMemory() = default; + explicit MaybeOwningDeviceMemory(OwningDeviceMemory owned) + : mem_(std::move(owned)) {} + explicit MaybeOwningDeviceMemory(se::DeviceMemoryBase unowned) + : mem_(unowned) {} + MaybeOwningDeviceMemory(MaybeOwningDeviceMemory&&) = default; + ~MaybeOwningDeviceMemory() = default; + + MaybeOwningDeviceMemory& operator=(se::DeviceMemoryBase unowned) { + mem_ = unowned; + return *this; + } + + MaybeOwningDeviceMemory& operator=(OwningDeviceMemory owned) { + mem_ = std::move(owned); + return *this; + } + + MaybeOwningDeviceMemory& operator=(MaybeOwningDeviceMemory&&) = default; + + // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The + // caller of this function is *not* responsible for freeing the memory. + se::DeviceMemoryBase AsDeviceMemoryBase(); + + // Release the OwningDeviceMemory without freeing it, and moves the ownership + // of the memory buffer from the object to the caller. + // + // A nullopt is returned if the HasOwnership() == false; + absl::optional Release(); + + // Returns true if the device_memory has ownership over underlying memory. + bool HasOwnership() const; + + private: + absl::variant mem_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 4166ef5baf9c891968b584a0c498005e9ae87784..b9ec31c4977be0c31dfff01a0c495902191d7d5b 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -262,7 +262,7 @@ void MultiOutputFusion::RecomputeReachability() { void MultiOutputFusion::UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, - tensorflow::gtl::ArraySlice instrs_to_update, + absl::Span instrs_to_update, const std::function& skip) { for (auto instr : instrs_to_update) { if (skip != nullptr && skip(instr)) { diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 0019cd725417d81900974b462c3b05075ce3e893..0344626b26b2cd1d659657c51636266706d17afb 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -44,13 +44,11 @@ namespace xla { // Note that the reachability map is updated based on the original computation. // This works because the reachability is monotonically increasing with // instruction fusion. -class MultiOutputFusion : public HloPassInterface { +class MultiOutputFusion : public HloModulePass { public: MultiOutputFusion(int64 fuel) : fuel_(fuel) {} - tensorflow::StringPiece name() const override { - return "multi_output_fusion"; - } + absl::string_view name() const override { return "multi_output_fusion"; } // Run multi-output fusion on the given module. Returns whether the module // was changed. @@ -94,7 +92,7 @@ class MultiOutputFusion : public HloPassInterface { // Update the reachability map after fusing instr1 and instr2. void UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, - tensorflow::gtl::ArraySlice instrs_to_update, + absl::Span instrs_to_update, const std::function& skip = nullptr); // Hook for multi-output fusion along producer-consumer edges. @@ -104,17 +102,17 @@ class MultiOutputFusion : public HloPassInterface { // InstructionFusion instead. virtual bool DoProducerConsumerMultiOutputFusion(); - private: - // Update the internal data structures after instr1 and instr2 are fused into - // one fusion instruction. - void Update(HloInstruction* instr1, HloInstruction* instr2); - // Optimization fuel is a compiler debugging technique that makes an // optimization pass stop what it is doing after having made N changes to the // program, where N is the fuel. By varying N, this can be used to find the // first single change that makes a test fail. int64 fuel_; + private: + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + // Computation for the pass. HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index f6e7578a89551ec2f23d4d8c8b488c3c10e0bf1c..ac2f79674feceff436c0e9c65338967f498e4473 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -38,8 +39,10 @@ NameUniquer::NameUniquer(const string& separator) { } /*static*/ string NameUniquer::GetSanitizedName(const string& name) { + if (name.empty()) { + return ""; + } string result = name; - CHECK(!result.empty()) << "name should not be empty"; char c = static_cast(result[0]); if (!isalpha(c) && c != '_') { result[0] = '_'; @@ -52,8 +55,8 @@ NameUniquer::NameUniquer(const string& separator) { return result; } -string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { - string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); +string NameUniquer::GetUniqueName(absl::string_view prefix) { + string root = GetSanitizedName(prefix.empty() ? "name" : string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. @@ -63,20 +66,22 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { if (separator_index != string::npos && (separator_index > 0) && (separator_index < root.size() - 1)) { string after_suffix = root.substr(separator_index + 1); - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); + } else { + // absl::SimpleAtoi may modify numeric_suffix even if it returns false. + numeric_suffix = 0; } } SequentialIdGenerator& id_generator = generated_names_[root]; numeric_suffix = id_generator.RegisterId(numeric_suffix); if (numeric_suffix == 0) { - return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) - : root; + return has_numeric_suffix ? absl::StrCat(root, separator_, 0) : root; } - tensorflow::strings::StrAppend(&root, separator_, numeric_suffix); + absl::StrAppend(&root, separator_, numeric_suffix); return root; } diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 4423d6106920eaeab830bd9dc08529ff409a5161..6dd89c240f81c9f0ccac66e50c7f244bfd5429f1 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" @@ -38,7 +38,7 @@ class NameUniquer { // Get a sanitized unique name in a string, with an optional prefix for // convenience. - string GetUniqueName(tensorflow::StringPiece prefix = ""); + string GetUniqueName(absl::string_view prefix = ""); // Sanitizes and returns the name. Unallowed characters will be replaced with // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index ac6ea4c72f61a47726b3ae7dd000837d3fba1b93..380cde0e6a858c7800445be94bb08dc22f3e776a 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -16,11 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#include "absl/strings/string_view.h" +#include "absl/utility/utility.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -116,15 +120,82 @@ namespace xla { // .WithOperand(1, Op(&c)) // .WithOperand(2, Op(&d)) // + +struct MatchOption { + // If true, actually capture matched item into the user pointer. + bool capture; +}; + template -bool Match(Value* value, const Pattern& pattern) { - return pattern.Match(value); +bool Match(Value* value, const Pattern& pattern, + MatchOption option = {/*.capture=*/true}) { + if (option.capture) { + auto new_option = option; + new_option.capture = false; + if (!pattern.Match(value, new_option)) { + return false; + } + } + return pattern.Match(value, option); } namespace match { namespace detail { +template +class AllOfPattern { + public: + explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} + + bool Match(const Item* item, MatchOption option) const { + bool matched = MatchImpl(item, option, std::integral_constant()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; + } + + bool Match(Item* item, MatchOption option) const { + bool matched = MatchImpl(item, option, std::integral_constant()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; + } + + private: + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + return std::get(patterns_).Match(item, option) && + MatchImpl(item, option, std::integral_constant()); + } + + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + return true; + } + + std::tuple patterns_; +}; + +} // namespace detail + +// Returns a pattern that represents the conjunction of all input patterns. All +// patterns need to match in order to have the AllOf pattern match. +// +// TODO(timshen): Currently AllOf is still nested, e.g. AllOf, B> is +// not AllOf. We might want to flatten the AllOf type structure if the +// C++ compile error message gets annoying. +template +detail::AllOfPattern::type, Patterns...> AllOf( + const Patterns&... patterns) { + return detail::AllOfPattern::type, + Patterns...>(patterns...); +} + +namespace detail { + template class LayoutPattern; @@ -132,57 +203,61 @@ class LayoutPattern; // nullptr. class LayoutPatternBaseImpl { public: - bool Match(const ::xla::Layout* layout) const { return layout != nullptr; } + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return layout != nullptr; + } }; // A LayoutPattern implementation that matches only if the layout equals a // Layout proto. -template class LayoutPatternEqualImpl { public: - explicit constexpr LayoutPatternEqualImpl(const Previous& previous, - const ::xla::Layout* layout) - : previous_(previous), layout_(layout) {} + explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout) + : layout_(layout) {} - bool Match(const ::xla::Layout* layout) const { - return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout); + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return LayoutUtil::Equal(*layout_, *layout); } private: - Previous previous_; const ::xla::Layout* layout_; }; // A LayoutPattern implementation that matches only if the layout has a given // format. -template class LayoutPatternFormatImpl { public: - explicit constexpr LayoutPatternFormatImpl(const Previous& previous, - Format format) - : previous_(previous), format_(format) {} + explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {} - bool Match(const ::xla::Layout* layout) const { - return previous_.Match(layout) && layout->format() == format_; + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return layout->format() == format_; } private: - Previous previous_; Format format_; }; // A pattern that matches Layouts. template class LayoutPattern { + private: + template + LayoutPattern> + AppendImpl(NewImpl new_impl) const { + return LayoutPattern>( + AllOf(impl_, std::move(new_impl)), matched_layout_); + } + public: explicit constexpr LayoutPattern(const Impl& impl, LayoutType** matched_layout) : impl_(impl), matched_layout_(matched_layout) {} // Returns true and captures the layout iff it matches the pattern. - bool Match(const ::xla::Layout* layout) const { - if (impl_.Match(layout)) { - if (matched_layout_) { + bool Match(const ::xla::Layout* layout, MatchOption option) const { + if (impl_.Match(layout, option)) { + if (option.capture && matched_layout_) { *matched_layout_ = layout; } return true; @@ -191,9 +266,9 @@ class LayoutPattern { } // Returns true and captures the layout iff it matches the pattern. - bool Match(::xla::Layout* layout) const { - if (impl_.Match(layout)) { - if (matched_layout_) { + bool Match(::xla::Layout* layout, MatchOption option) const { + if (impl_.Match(layout, option)) { + if (option.capture && matched_layout_) { *matched_layout_ = layout; } return true; @@ -203,24 +278,21 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. - constexpr LayoutPattern> EqualTo( - const ::xla::Layout* layout) const { - return LayoutPattern>( - LayoutPatternEqualImpl(impl_, layout), matched_layout_); + constexpr auto EqualTo(const ::xla::Layout* layout) const + -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) { + return AppendImpl(LayoutPatternEqualImpl(layout)); } // Modifies the pattern to match only if the layout has a dense format. - constexpr LayoutPattern> - WithDenseFormat() const { - return LayoutPattern>( - LayoutPatternFormatImpl(impl_, DENSE), matched_layout_); + constexpr auto WithDenseFormat() const + -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) { + return AppendImpl(LayoutPatternFormatImpl(DENSE)); } // Modifies the pattern to match only if the layout has a sparse format. - constexpr LayoutPattern> - WithSparseFormat() const { - return LayoutPattern>( - LayoutPatternFormatImpl(impl_, SPARSE), matched_layout_); + constexpr auto WithSparseFormat() const + -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) { + return AppendImpl(LayoutPatternFormatImpl(SPARSE)); } private: @@ -228,8 +300,72 @@ class LayoutPattern { LayoutType** matched_layout_; }; +template +class AnyOfPattern { + public: + explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} + + bool Match(const Item* item, MatchOption option) const { + return MatchImpl(item, option, std::integral_constant()); + } + + bool Match(Item* item, MatchOption option) const { + return MatchImpl(item, option, std::integral_constant()); + } + + private: + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + auto new_option = option; + new_option.capture = false; + // Try to match the sub-pattern without capturing behavior. + if (std::get(patterns_).Match(item, new_option)) { + // Capture the branch. + if (option.capture) { + // TODO(timshen): Currently the behavior can be exponential. Optimize it + // with memoization or recording the matched sub-pattern index, if it + // takes too long to run. + // + // Specifically, the "memoization" approach is to create an empty + // container with the key (pattern, instruction), and value as whether + // matched or not. + // + // Alternatively, we may run the pattern matching with captures off, but + // instead record a "trace" somewhere, indicating how exactly the + // pattern matches the input. For example, the trace information for + // AnyOf will be a runtime number indicate which sub-pattern is matched. + // Then we run another pass to do captures only with the help of the + // trace. + bool ret = std::get(patterns_).Match(item, option); + DCHECK(ret); + } + return true; + } + return MatchImpl(item, option, std::integral_constant()); + } + + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + return false; + } + + std::tuple patterns_; +}; + } // namespace detail +// Returns a pattern that represents the logical disjunction of the input +// patterns. The returned pattern matches from left to right, and stops on the +// first match. +template +detail::AnyOfPattern::type, Patterns...> AnyOf( + const Patterns&... patterns) { + return detail::AnyOfPattern::type, + Patterns...>(patterns...); +} + // Creates a layout pattern that will capture the matched layout in the // argument. inline constexpr detail::LayoutPattern class ShapePatternEqualImpl { public: - explicit constexpr ShapePatternEqualImpl(const Previous& previous, - const ::xla::Shape* shape) - : previous_(previous), shape_(shape) {} + explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape) + : shape_(shape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Equal(*shape_, *shape); } private: - Previous previous_; const ::xla::Shape* shape_; }; // A ShapePattern implementation that matches only if the shape is compatible to // a Shape proto. -template class ShapePatternCompatibleImpl { public: - explicit constexpr ShapePatternCompatibleImpl(const Previous& previous, - const ::xla::Shape* shape) - : previous_(previous), shape_(shape) {} + explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape) + : shape_(shape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Compatible(*shape_, *shape); } private: - Previous previous_; const ::xla::Shape* shape_; }; // A ShapePattern implementation that matches only if the shape has a given // element type. -template class ShapePatternElementTypeImpl { public: - explicit constexpr ShapePatternElementTypeImpl(const Previous& previous, - PrimitiveType element_type) - : previous_(previous), element_type_(element_type) {} + explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type) + : element_type_(element_type) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && shape->element_type() == element_type_; + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return shape->element_type() == element_type_; } private: - Previous previous_; PrimitiveType element_type_; }; // A ShapePattern implementation that matches only if the shape is scalar. -template class ShapePatternIsScalarImpl { public: - explicit constexpr ShapePatternIsScalarImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsScalarImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsScalar(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsScalar(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape is an array -template class ShapePatternIsArrayImpl { public: - explicit constexpr ShapePatternIsArrayImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsArrayImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsArray(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsArray(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape is a tuple. -template class ShapePatternIsTupleImpl { public: - explicit constexpr ShapePatternIsTupleImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsTupleImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsTuple(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsTuple(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape has a given // rank. -template class ShapePatternRankImpl { public: - explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank) - : previous_(previous), rank_(rank) {} + explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_; + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Rank(*shape) == rank_; } private: - Previous previous_; int64 rank_; }; // A ShapePattern implementation that matches only if the shape has a layout // that matches a given pattern. -template +template class ShapePatternLayoutImpl { public: explicit constexpr ShapePatternLayoutImpl( - const Previous& previous, const LayoutPattern& layout) - : previous_(previous), layout_(layout) {} + : layout_(layout) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && - layout_.Match(&shape->layout()); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return LayoutUtil::HasLayout(*shape) && + layout_.Match(&shape->layout(), option); } - bool Match(Shape* shape) const { - return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && - layout_.Match(shape->mutable_layout()); + bool Match(Shape* shape, MatchOption option) const { + return LayoutUtil::HasLayout(*shape) && + layout_.Match(shape->mutable_layout(), option); } private: - Previous previous_; LayoutPattern layout_; }; // A ShapePattern implementation that matches only if the shape has a subshape // that matches a given pattern. -template +template class ShapePatternSubshapeImpl { public: explicit ShapePatternSubshapeImpl( - const Previous& previous, ShapeIndexView index, + ShapeIndexView index, const ShapePattern& subshape) - : previous_(previous), index_(index), subshape_(subshape) {} + : index_(index), subshape_(subshape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_)); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option); } - bool Match(::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_)); + bool Match(::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_), + option); } private: - Previous previous_; ShapeIndexView index_; ShapePattern subshape_; }; @@ -431,14 +540,22 @@ class ShapePatternSubshapeImpl { // A pattern that matches Shapes. template class ShapePattern { + private: + template + ShapePattern> AppendImpl( + NewImpl new_impl) const { + return ShapePattern>( + AllOf(impl_, std::move(new_impl)), matched_shape_); + } + public: explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape) : impl_(impl), matched_shape_(matched_shape) {} // Returns true and captures the shape iff it matches the pattern. - bool Match(const ::xla::Shape* shape) const { - if (impl_.Match(shape)) { - if (matched_shape_) { + bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (impl_.Match(shape, option)) { + if (option.capture && matched_shape_) { *matched_shape_ = shape; } return true; @@ -447,9 +564,9 @@ class ShapePattern { } // Returns true and captures the shape iff it matches the pattern. - bool Match(::xla::Shape* shape) const { - if (impl_.Match(shape)) { - if (matched_shape_) { + bool Match(::xla::Shape* shape, MatchOption option) const { + if (impl_.Match(shape, option)) { + if (option.capture && matched_shape_) { *matched_shape_ = shape; } return true; @@ -459,108 +576,90 @@ class ShapePattern { // Modifies the pattern to match only if the shape equals the given proto. // The layout must outlive the returned pattern. - constexpr ShapePattern> EqualTo( - const ::xla::Shape* shape) const { - return ShapePattern>( - ShapePatternEqualImpl(impl_, shape), matched_shape_); + constexpr auto EqualTo(const ::xla::Shape* shape) const + -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) { + return AppendImpl(ShapePatternEqualImpl(shape)); } // Modifies the pattern to match only if the shape is compatible to the given // proto. The layout must outlive the returned pattern. - constexpr ShapePattern> - CompatibleTo(const ::xla::Shape* shape) const { - return ShapePattern>( - ShapePatternCompatibleImpl(impl_, shape), matched_shape_); + constexpr auto CompatibleTo(const ::xla::Shape* shape) const + -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) { + return AppendImpl(ShapePatternCompatibleImpl(shape)); } // Modifies the pattern to match only if the shape has the given element type. - constexpr ShapePattern> - WithElementType(PrimitiveType element_type) const { - return ShapePattern>( - ShapePatternElementTypeImpl(impl_, element_type), matched_shape_); + constexpr auto WithElementType(PrimitiveType element_type) const + -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) { + return AppendImpl(ShapePatternElementTypeImpl(element_type)); } // Modifies the pattern to match only if the shape is scalar. - constexpr ShapePattern> IsScalar() - const { - return ShapePattern>( - ShapePatternIsScalarImpl(impl_), matched_shape_); + constexpr auto IsScalar() const + -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) { + return AppendImpl(ShapePatternIsScalarImpl()); } // Modifies the pattern to match only if the shape is an array. - constexpr ShapePattern> IsArray() - const { - return ShapePattern>( - ShapePatternIsArrayImpl(impl_), matched_shape_); + constexpr auto IsArray() const + -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) { + return AppendImpl(ShapePatternIsArrayImpl()); } // Modifies the pattern to match only if the shape is a tuple. - constexpr ShapePattern> IsTuple() - const { - return ShapePattern>( - ShapePatternIsTupleImpl(impl_), matched_shape_); + constexpr auto IsTuple() const + -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) { + return AppendImpl(ShapePatternIsTupleImpl()); } // Modifies the pattern to match only if the shape has the given rank. - constexpr ShapePattern> WithRank( - int64 rank) const { - return ShapePattern>( - ShapePatternRankImpl(impl_, rank), matched_shape_); + constexpr auto WithRank(int64 rank) const + -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { + return AppendImpl(ShapePatternRankImpl(rank)); } // Modifies the pattern to match only if the shape has a layout that matches // the given pattern. template - constexpr ShapePattern> - WithLayout(const LayoutPattern& layout) const { - return ShapePattern>( - ShapePatternLayoutImpl(impl_, layout), - matched_shape_); - } - - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl>> - WithLayoutEqualTo(const ::xla::Layout* layout) const { + auto WithLayout(const LayoutPattern& layout) const + -> decltype(this->AppendImpl( + ShapePatternLayoutImpl(layout))) { + return AppendImpl(ShapePatternLayoutImpl(layout)); + } + + constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const + -> decltype(this->WithLayout(Layout().EqualTo(layout))) { return WithLayout(Layout().EqualTo(layout)); } - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl>> - IsDenseArray() const { + constexpr auto IsDenseArray() const + -> decltype(this->WithLayout(Layout().WithDenseFormat())) { return WithLayout(Layout().WithDenseFormat()); } - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl>> - IsSparseArray() const { + constexpr auto IsSparseArray() const + -> decltype(this->WithLayout(Layout().WithSparseFormat())) { return WithLayout(Layout().WithSparseFormat()); } // Modifies the pattern to match only if the shape has a subshape that matches // the given pattern. template + auto WithSubshape(ShapeIndexView index, + const ShapePattern& subshape) + const -> decltype(this->AppendImpl( + ShapePatternSubshapeImpl(index, + subshape))) { + return AppendImpl( + ShapePatternSubshapeImpl(index, subshape)); + } + ShapePattern> - WithSubshape(ShapeIndexView index, - const ShapePattern& subshape) const { - return ShapePattern< - ShapeType, ShapePatternSubshapeImpl>( - ShapePatternSubshapeImpl(impl_, index, - subshape), - matched_shape_); - } - - ShapePattern>> + AllOfPattern>>> WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const { return WithSubshape(index, ShapePattern( @@ -568,9 +667,12 @@ class ShapePattern { .EqualTo(shape)); } - ShapePattern>> + ShapePattern>>> WithSubshapeCompatibleTo(ShapeIndexView index, const ::xla::Shape* shape) const { return WithSubshape(index, @@ -611,159 +713,169 @@ class HloInstructionPattern; // instruction is not nullptr. class HloInstructionPatternBaseImpl { public: - bool Match(const ::xla::HloInstruction* inst) const { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { return inst != nullptr; } }; // An HloInstructionPattern implementation that matches only if the instruction // has a given name. -template class HloInstructionPatternNameImpl { public: - explicit HloInstructionPatternNameImpl(const Previous& previous, - tensorflow::StringPiece name) - : previous_(previous), name_(name) {} + explicit HloInstructionPatternNameImpl(absl::string_view name) + : name_(name) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->name() == name_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->name() == name_; } private: - Previous previous_; - tensorflow::StringPiece name_; + absl::string_view name_; }; // An HloInstructionPattern implementation that matches only if the instruction // has a given opcode. -template class HloInstructionPatternOpcodeImpl { public: - explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous, - HloOpcode opcode, + explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode, bool invert) - : previous_(previous), opcode_(opcode), invert_(invert) {} + : opcode_(opcode), invert_(invert) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_)); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return (invert_ ^ (inst->opcode() == opcode_)); } private: - Previous previous_; HloOpcode opcode_; bool invert_; }; // An HloInstructionPattern implementation that matches only if the instruction // has a shape that matches a given pattern. -template +template class HloInstructionPatternShapeImpl { public: explicit constexpr HloInstructionPatternShapeImpl( - const Previous& previous, const ShapePattern& shape) - : previous_(previous), shape_(shape) {} + const ShapePattern& shape) + : shape_(shape) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && shape_.Match(&inst->shape()); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return shape_.Match(&inst->shape(), option); } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && shape_.Match(inst->mutable_shape()); + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return shape_.Match(inst->mutable_shape(), option); } private: - Previous previous_; ShapePattern shape_; }; // An HloInstructionPattern implementation that matches only if the instruction // has an operand that matches a given pattern. -template +template class HloInstructionPatternOperandImpl { public: explicit constexpr HloInstructionPatternOperandImpl( - const Previous& previous, int64 operand_index, + int64 operand_index, const HloInstructionPattern& operand) - : previous_(previous), operand_index_(operand_index), operand_(operand) {} + : operand_index_(operand_index), operand_(operand) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && operand_index_ < inst->operand_count() && - operand_.Match(inst->operand(operand_index_)); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return operand_index_ < inst->operand_count() && + operand_.Match(inst->operand(operand_index_), option); } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && operand_index_ < inst->operand_count() && - operand_.Match(inst->mutable_operand(operand_index_)); + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return operand_index_ < inst->operand_count() && + operand_.Match(inst->mutable_operand(operand_index_), option); } private: - Previous previous_; int64 operand_index_; HloInstructionPattern operand_; }; // An HloInstructionPattern implementation that matches only if the instruction // is a fusion node with a particular kind. -template class HloInstructionPatternFusionKindImpl { public: explicit constexpr HloInstructionPatternFusionKindImpl( - const Previous& previous, ::xla::HloInstruction::FusionKind kind) - : previous_(previous), kind_(kind) {} + ::xla::HloInstruction::FusionKind kind) + : kind_(kind) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && - inst->fusion_kind() == kind_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && - inst->fusion_kind() == kind_; + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; } private: - Previous previous_; ::xla::HloInstruction::FusionKind kind_; }; // An HloInstructionPattern implementation that matches only if the instruction // is a kGetTupleElement with a particular tuple index. -template class HloInstructionPatternTupleIndexImpl { public: - explicit constexpr HloInstructionPatternTupleIndexImpl( - const Previous& previous, int64 tuple_index) - : previous_(previous), tuple_index_(tuple_index) {} + explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index) + : tuple_index_(tuple_index) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && - inst->opcode() == HloOpcode::kGetTupleElement && + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kGetTupleElement && inst->tuple_index() == tuple_index_; } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && - inst->opcode() == HloOpcode::kGetTupleElement && + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kGetTupleElement && inst->tuple_index() == tuple_index_; } private: - Previous previous_; int64 tuple_index_; }; +template +class HloPredicatePatternImpl { + public: + explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {} + + bool Match(const ItemType* item, MatchOption option) const { + return pred_(item); + } + + bool Match(ItemType* item, MatchOption option) const { return pred_(item); } + + private: + Predicate pred_; +}; + +struct PatternFriend; + // A pattern that matches HloInstructions. template class HloInstructionPattern { + private: + template + HloInstructionPattern> + AppendImpl(NewImpl new_impl) const { + return HloInstructionPattern< + HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>( + AllOf(impl_, std::move(new_impl)), matched_inst_); + } + public: explicit constexpr HloInstructionPattern(const Impl& impl, HloInstructionType** matched_inst) : impl_(impl), matched_inst_(matched_inst) {} // Returns true and captures the instruction iff it matches the pattern. - bool Match(const ::xla::HloInstruction* inst) const { - if (impl_.Match(inst)) { - if (matched_inst_) { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (impl_.Match(inst, option)) { + if (option.capture && matched_inst_) { *matched_inst_ = inst; } return true; @@ -772,9 +884,9 @@ class HloInstructionPattern { } // Returns true and captures the instruction iff it matches the pattern. - bool Match(::xla::HloInstruction* inst) const { - if (impl_.Match(inst)) { - if (matched_inst_) { + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + if (impl_.Match(inst, option)) { + if (option.capture && matched_inst_) { *matched_inst_ = inst; } return true; @@ -783,102 +895,87 @@ class HloInstructionPattern { } // Modifies the pattern to match only if the instruction has the given name. - HloInstructionPattern> - WithName(tensorflow::StringPiece name) const { - return HloInstructionPattern>( - HloInstructionPatternNameImpl(impl_, name), matched_inst_); + auto WithName(absl::string_view name) const + -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) { + return AppendImpl(HloInstructionPatternNameImpl(name)); } // Modifies the pattern to match only if the instruction has the given opcode. - constexpr HloInstructionPattern> - WithOpcode(HloOpcode opcode) const { - return HloInstructionPattern>( - HloInstructionPatternOpcodeImpl(impl_, opcode, false), - matched_inst_); + auto WithOpcode(HloOpcode opcode) const + -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, + false))) { + return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false)); } // Modifies the pattern to match only if the instruction does not have the // given opcode. - constexpr HloInstructionPattern> - WithoutOpcode(HloOpcode opcode) const { - return HloInstructionPattern>( - HloInstructionPatternOpcodeImpl(impl_, opcode, true), - matched_inst_); + auto WithoutOpcode(HloOpcode opcode) const + -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, + true))) { + return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); } // Modifies the pattern to match only if the instruction is a constant. - constexpr HloInstructionPattern> - IsConstant() const { + constexpr auto IsConstant() const + -> decltype(this->WithOpcode(HloOpcode::kConstant)) { return WithOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction is not a constant. - constexpr HloInstructionPattern> - IsNonConstant() const { + constexpr auto IsNonConstant() const + -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { return WithoutOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction has a shape that // matches the given pattern. template - constexpr HloInstructionPattern< - HloInstructionType, - HloInstructionPatternShapeImpl> - WithShape(const ShapePattern& shape) const { - return HloInstructionPattern< - HloInstructionType, - HloInstructionPatternShapeImpl>( - HloInstructionPatternShapeImpl(impl_, - shape), - matched_inst_); + constexpr auto WithShape(const ShapePattern& shape) + const -> decltype(this->AppendImpl( + HloInstructionPatternShapeImpl(shape))) { + return AppendImpl( + HloInstructionPatternShapeImpl(shape)); } // Modifies the pattern to match only if the instruction has an operand that // matches the given pattern. template - constexpr HloInstructionPattern< - HloInstructionType, - HloInstructionPatternOperandImpl> - WithOperand( + constexpr auto WithOperand( int64 operand_index, - const HloInstructionPattern& operand) const { - return HloInstructionPattern< - HloInstructionType, - HloInstructionPatternOperandImpl>( - HloInstructionPatternOperandImpl( - impl_, operand_index, operand), - matched_inst_); + const HloInstructionPattern& operand) const + -> decltype(this->AppendImpl( + HloInstructionPatternOperandImpl( + operand_index, operand))) { + return AppendImpl( + HloInstructionPatternOperandImpl( + operand_index, operand)); } // Modifies the pattern to match only if the instruction is a fusion node with // the given kind. - constexpr HloInstructionPattern> - WithFusionKind(HloInstruction::FusionKind kind) const { - return HloInstructionPattern>( - HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); + constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const + -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) { + return AppendImpl(HloInstructionPatternFusionKindImpl(kind)); } // Modifies the pattern to match only if the instruction is a // get-tuple-element with the given tuple index. - constexpr HloInstructionPattern> - WithTupleIndex(int64 tuple_index) const { - return HloInstructionPattern>( - HloInstructionPatternTupleIndexImpl(impl_, tuple_index), - matched_inst_); + constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype( + this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) { + return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); } private: + template + constexpr auto WithPredicate(Predicate pred) const -> decltype( + this->AppendImpl(HloPredicatePatternImpl( + std::move(pred)))) { + return AppendImpl( + HloPredicatePatternImpl(std::move(pred))); + } + + friend struct PatternFriend; + Impl impl_; HloInstructionType** matched_inst_; }; @@ -918,6 +1015,7 @@ Op(::xla::HloInstruction** matched_inst) { } XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) +XLA_NULLOP_PATTERN(Iota) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. @@ -1004,31 +1102,50 @@ XLA_UNOP_PATTERN(Transpose) .WithOperand(0, std::forward(lhs)) \ .WithOperand(1, std::forward(rhs)); \ } -XLA_BINOP_PATTERN(Add) + +#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ + XLA_BINOP_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(AnyOf(NAME(lhs, rhs), NAME(rhs, lhs))) { \ + return AnyOf(NAME(lhs, rhs), NAME(rhs, lhs)); \ + } \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) \ + ->decltype(AnyOf(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs))) { \ + return AnyOf(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs)); \ + } +XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) XLA_BINOP_PATTERN(Dot) -XLA_BINOP_PATTERN(Eq) +XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) XLA_BINOP_PATTERN(Gt) XLA_BINOP_PATTERN(Le) XLA_BINOP_PATTERN(Lt) -XLA_BINOP_PATTERN(Maximum) -XLA_BINOP_PATTERN(Minimum) -XLA_BINOP_PATTERN(Multiply) -XLA_BINOP_PATTERN(Ne) +XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) +XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) +XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) +XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Power) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) -XLA_BINOP_PATTERN(And) -XLA_BINOP_PATTERN(Or) +XLA_COMMUTATIVE_BINOP_PATTERN(And) +XLA_COMMUTATIVE_BINOP_PATTERN(Or) XLA_BINOP_PATTERN(ShiftLeft) XLA_BINOP_PATTERN(ShiftRightArithmetic) XLA_BINOP_PATTERN(ShiftRightLogical) +#undef XLA_COMMUTATIVE_BINOP_PATTERN #undef XLA_BINOP_PATTERN // Helpers for ternary instructions. @@ -1069,6 +1186,30 @@ XLA_TERNOP_PATTERN(Clamp); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN +namespace detail { +struct PatternFriend { + template + static auto ConstantScalar(T constant) -> decltype( + Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate( + std::declval>())) { + std::function pred = + [constant](const HloInstruction* instr) { + const auto& literal = Cast(instr)->literal(); + auto status_or_const = LiteralUtil::CreateR0(constant).Convert( + literal.shape().element_type()); + return status_or_const.ok() && + literal == status_or_const.ConsumeValueOrDie(); + }; + + return Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate(std::move(pred)); + } +}; +} // namespace detail + // Helpers for matching non-constant instructions. inline auto NonConstant() -> decltype(Op().IsNonConstant()) { return Op().IsNonConstant(); @@ -1106,6 +1247,12 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, .WithTupleIndex(tuple_index); } +template +inline auto ConstantScalar(T constant) + -> decltype(detail::PatternFriend::ConstantScalar(constant)) { + return detail::PatternFriend::ConstantScalar(constant); +} + } // namespace match } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index a530581c34bf1d699eae3c53203c197f7943cc53..3ab7b7fd7168d7ddd1470fdb03a04ba7b171fddb 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -211,5 +211,188 @@ TEST(PatternMatcherTest, GetTupleElement) { EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); } +TEST(PatternMatcherTest, AnyOf) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE( + Match(root, match::AnyOf(match::ConstantScalar(0), + match::ConstantScalar(1)))); + EXPECT_TRUE( + Match(root, match::AnyOf(match::ConstantScalar(1), + match::ConstantScalar(0)))); + EXPECT_FALSE( + Match(root, match::AnyOf(match::ConstantScalar(0), + match::ConstantScalar(2)))); +} + +TEST(PatternMatcherTest, ConstantScalar) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE(Match(root, match::ConstantScalar(42))); + EXPECT_FALSE(Match(root, match::ConstantScalar(41))); + EXPECT_FALSE(Match(root, match::ConstantScalar(0))); +} + +TEST(PatternMatcherTest, NoMatchConstantScalar) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_FALSE(Match(root, match::ConstantScalar(42))); +} + +TEST(PatternMatcherTest, MultiplyAnyOrder) { + using match::ConstantScalar; + using match::MultiplyAnyOrder; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + const HloInstruction* instr; + + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)))); + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42)))); +} + +TEST(PatternMatcherTest, AnyOfShortCircuit) { + using match::AnyOf; + using match::Multiply; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf(Multiply(&mul, Op(), Op()), Op(&any)))); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(nullptr, any); + } + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf(Op(&any), Multiply(&mul, Op(), Op())))); + EXPECT_NE(nullptr, any); + EXPECT_EQ(nullptr, mul); + } +} + +TEST(PatternMatcherTest, AllOf) { + using match::AllOf; + using match::Broadcast; + using match::Constant; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar()); + auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16)); + ASSERT_TRUE(Match(root, scalar_pattern)); + ASSERT_TRUE(Match(root, f16_pattern)); + EXPECT_TRUE(Match(root, AllOf(scalar_pattern, f16_pattern))); + EXPECT_TRUE(Match(root, AllOf(f16_pattern, scalar_pattern))); + EXPECT_FALSE( + Match(root, AllOf(Broadcast(Op()), f16_pattern))); + EXPECT_FALSE( + Match(root, AllOf(Broadcast(Op()), scalar_pattern))); +} + +TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) { + using match::AllOf; + using match::Broadcast; + using match::Constant; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_FALSE( + Match(root, AllOf(Constant(&constant), Broadcast(Op())))); + EXPECT_EQ(nullptr, constant); + ASSERT_TRUE(Match(root, Constant(&constant))); + EXPECT_NE(nullptr, constant); +} + +TEST(PatternMatcherTest, TestNoCapture) { + using match::Constant; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_TRUE(Match(root, Constant(&constant), {/*capture=*/false})); + EXPECT_EQ(nullptr, constant); +} + +TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) { + using match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + u = f16[] parameter(0) + v = f16[] parameter(1) + ROOT add = f16[] add(u, v) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* addend0 = nullptr; + const HloInstruction* addend1 = nullptr; + const HloInstruction* addend2 = nullptr; + auto add2_pattern = Add(Op(&addend0), Op(&addend1)); + auto add3_pattern = AnyOf( + AddAnyOrder(add2_pattern, Op(&addend2)), add2_pattern, Op(&addend0)); + + ASSERT_TRUE(Match(root, add3_pattern)); + EXPECT_NE(nullptr, addend0); + EXPECT_NE(nullptr, addend1); + EXPECT_EQ(nullptr, addend2); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 39fe3c7835d1c74c0f1e5bc0ebf5916ec734c24a..c522e7ae23b734090f85d241bf365fccc37f0adb 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -19,20 +19,20 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { -using tensorflow::str_util::Lowercase; - // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; @@ -43,7 +43,7 @@ constexpr char kInterpreter[] = "interpreter"; namespace { string CanonicalPlatformName(const string& name) { - string platform_str = Lowercase(name); + string platform_str = absl::AsciiStrToLower(name); // "cpu" and "host" mean the same thing. if (platform_str == "cpu") { platform_str = "host"; @@ -90,41 +90,54 @@ PlatformUtil::GetSupportedPlatforms() { if (platforms.empty()) { return NotFound("no platforms found"); } else if (platforms.size() == 1) { - return platforms[0]; + se::Platform* platform = platforms[0]; + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + return platform; } // Multiple platforms present and we can't pick a reasonable default. - string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", - platforms_string.c_str()); + platforms_string); } /* static */ StatusOr PlatformUtil::GetDefaultPlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + + se::Platform* platform = nullptr; if (platforms.empty()) { return NotFound("no platforms found"); } else if (platforms.size() == 1) { - return platforms[0]; + platform = platforms[0]; } else if (platforms.size() == 2) { for (int i = 0; i < 2; i++) { - if (Lowercase(platforms[i]->Name()) == kInterpreter && - Lowercase(platforms[1 - i]->Name()) != kInterpreter) { - return platforms[1 - i]; + if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter && + absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) { + platform = platforms[1 - i]; + break; } } } + if (platform != nullptr) { + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + return platform; + } // Multiple platforms present and we can't pick a reasonable default. - string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform (except for the " "interpreter platform) found: %s", - platforms_string.c_str()); + platforms_string); } /*static*/ StatusOr PlatformUtil::GetPlatform( @@ -132,11 +145,14 @@ PlatformUtil::GetSupportedPlatforms() { string platform_str = CanonicalPlatformName(platform_name); TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) == platform_str) { + if (absl::AsciiStrToLower(platform->Name()) == platform_str) { + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } return platform; } } - return InvalidArgument("platform %s not found", platform_name.c_str()); + return InvalidArgument("platform %s not found", platform_name); } /*static*/ StatusOr PlatformUtil::GetPlatformExceptFor( @@ -146,23 +162,27 @@ PlatformUtil::GetSupportedPlatforms() { TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); std::vector matched; for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) != platform_name) { + if (absl::AsciiStrToLower(platform->Name()) != platform_name) { matched.push_back(platform); } } if (matched.empty()) { return InvalidArgument("unable to find platform that is not %s", - platform_name.c_str()); + platform_name); } if (matched.size() == 1) { - return matched[0]; + auto platform = matched[0]; + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + return platform; } - string matched_string = tensorflow::str_util::Join( + string matched_string = absl::StrJoin( matched, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "found multiple platforms %s, but expected one platform except for %s", - matched_string.c_str(), platform_name.c_str()); + matched_string, platform_name); } // Returns whether the device underlying the given StreamExecutor is supported @@ -193,14 +213,17 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) { PlatformUtil::GetStreamExecutors(se::Platform* platform) { int device_count = platform->VisibleDeviceCount(); if (device_count <= 0) { - return NotFound("no %s devices found", platform->Name().c_str()); + return NotFound("no %s devices found", platform->Name()); } if (platform->id() == se::host::kHostPlatformId) { // On host "devices", StreamExecutor exports a device for each hardware // thread. Because we parallelize a single computation across threads, it - // doesn't make sense to expose these as separate devices, so fix the number - // of devices to one. - device_count = 1; + // doesn't make sense to expose these as separate devices, so by default we + // fix the number of devices to one. However we do let the user override + // this behavior to help run tests on the host that run models in parallel + // across multiple devices. + device_count = legacy_flags::GetDebugOptionsFromFlags() + .xla_force_host_platform_device_count(); } std::vector stream_executors(device_count, nullptr); VLOG(1) << "Initializing devices"; @@ -232,7 +255,7 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { if (std::all_of(stream_executors.begin(), stream_executors.end(), [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", - platform->Name().c_str()); + platform->Name()); } return stream_executors; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index afde3cf95c721b59a39b74b4e1ff3f15a335fe97..4bb22428f3d66f27d268ac4490c6e2613966cbed 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -29,7 +29,7 @@ namespace xla { // HLO pass which inserts reduce-precision instructions into the HLO graph, for // purposes of experimenting with the effects of reduced-precision storage of // intermediate values. -class ReducePrecisionInsertion : public HloPassInterface { +class ReducePrecisionInsertion : public HloModulePass { using InstructionFilterFunction = std::function; public: @@ -59,7 +59,7 @@ class ReducePrecisionInsertion : public HloPassInterface { ~ReducePrecisionInsertion() override{}; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "reduce-precision-insertion"; } diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index ca86c5d13e98a98c62d0c9e8e32e28fe99e0fa1f..4df746fca9f8320eed72911726f33bb01f06fed5 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -38,6 +38,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include + +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -374,7 +376,7 @@ StatusOr TryReshapeMoveOnCandidates( removed = false; for (auto operand : nontrivial_operands) { - if (c_any_of(operand->users(), [&](HloInstruction* user) { + if (absl::c_any_of(operand->users(), [&](HloInstruction* user) { return !reshape_candidates->count(user); })) { for (auto* user : operand->users()) { diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index 1f59e3b3147facb6f2fae00d6c810bf54d560e5c..a3db439e34000ef3fcf4b190cb372947e285a64e 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -24,9 +24,9 @@ namespace xla { // This now only moves them outputward across elementwise ops all whose operands // are equivalent Reshapes or Transposes, but in future could potentially move // them inputward also. -class ReshapeMover : public HloPassInterface { +class ReshapeMover : public HloModulePass { public: - tensorflow::StringPiece name() const override { return "reshape-mover"; } + absl::string_view name() const override { return "reshape-mover"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index ad3b662c20ac53b0a6d634b16b3b908f730f3d2d..fcf269eee925c2ddb7511d70e71bd815e4b8c24a 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -28,13 +28,13 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using ReshapeMoverTest = HloVerifiedTestBase; + +namespace op = xla::testing::opcode_matchers; + +class ReshapeMoverTest : public HloVerifiedTestBase {}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); @@ -76,9 +76,13 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); - auto rng0 = builder.AddInstruction( - HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}), - RandomDistribution::RNG_UNIFORM, {})); + auto rng0 = builder.AddInstruction(HloInstruction::CreateRng( + ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}), + RandomDistribution::RNG_UNIFORM, + {builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(1.0f)))})); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0)); diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..de7aee262e61195b37099fc661a95508d0539e18 --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -0,0 +1,420 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/scatter_expander.h" + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + + +// Transposes the given scatter_indices such that the index_vector_dim becomes +// the most-minor dimension. +static StatusOr TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64 index_vector_dim) { + const Shape& scatter_indices_shape = scatter_indices->shape(); + + if (scatter_indices_shape.dimensions_size() == index_vector_dim) { + return scatter_indices; + } + + if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { + return scatter_indices; + } + + std::vector permutation; + permutation.reserve(scatter_indices_shape.dimensions_size()); + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(scatter_indices, permutation); +} + +// Canonicalizes the scatter_indices tensor in order to keep them uniform while +// performing the scatter operation. +static StatusOr CanonicalizeScatterIndices( + HloInstruction* scatter_indices, int64 index_vector_dim) { + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_scatter_indices, + TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + bool indices_are_scalar = + index_vector_dim == scatter_indices->shape().dimensions_size(); + + // The number of dimensions in scatter_indices that are index dimensions. + const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; + + // If there is only one index (i.e. scatter_indices has rank 1 and this + // scatter is really just a dynamic update slice) add a leading degenerate + // dimension for uniformity. Otherwise create a "collapsed" leading dimension + // that subsumes all of the non-index-vector dimensions. + const Shape& shape = transposed_scatter_indices->shape(); + if (shape.dimensions_size() == index_dims_in_scatter_indices) { + return PrependDegenerateDims(transposed_scatter_indices, 1); + } else { + // Collapse all but the dimensions (0 or 1) in scatter_indices containing + // the index vectors. + return CollapseFirstNDims( + transposed_scatter_indices, + shape.dimensions_size() - index_dims_in_scatter_indices); + } +} + +// Permutes the `updates` tensor such that all the scatter dims appear in the +// major dimensions and all the window dimensions appear in the minor +// dimensions. +static StatusOr PermuteScatterAndWindowDims( + HloInstruction* updates, absl::Span update_window_dims) { + std::vector permutation; + const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + permutation.reserve(updates_rank); + + for (int64 i = 0; i < updates_rank; ++i) { + bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i); + if (is_scatter_dim) { + permutation.push_back(i); + } + } + for (auto window_dim : update_window_dims) { + permutation.push_back(window_dim); + } + + return MakeTransposeHlo(updates, permutation); +} + +// Expands or contracts the scatter indices in the updates tensor. +static StatusOr AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64 index_vector_dim) { + int64 num_scatter_dims = scatter_indices_shape.dimensions_size(); + if (index_vector_dim < scatter_indices_shape.dimensions_size()) { + --num_scatter_dims; + } + if (num_scatter_dims == 0) { + // If there are no scatter dims, this must be a dynamic-update-slice kind of + // scatter. In this case, we prepend a degenerate dimension to work + // uniformly in the while loop. + return PrependDegenerateDims(updates, 1); + } + return CollapseFirstNDims(updates, num_scatter_dims); +} + +// Expands an index vector from the scatter_indices tensor into a vector that +// can be used to dynamic-update-slice to perform the scatter update. +static StatusOr ExpandIndexVectorIntoOperandSpace( + HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers, + int64 operand_rank) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector expanded_index_components; + + for (int i = 0; i < operand_rank; i++) { + int64 index_vector_dim_index = + FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i); + if (index_vector_dim_index != + dim_numbers.scatter_dims_to_operand_dims_size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + expanded_index_components.push_back(zero); + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + +static StatusOr CheckIndexValidity( + HloComputation* computation, HloInstruction* index, + absl::Span operand_dims, absl::Span window_sizes, + HloModule* module) { + DCHECK_NE(nullptr, module); + DCHECK_EQ(operand_dims.size(), window_sizes.size()); + + // Valid range for the index: [0, operand_dims - window_sizes] + + // Check if the index has any negative values. + TF_ASSIGN_OR_RETURN( + HloInstruction * zero_index, + BroadcastZeros(computation, index->shape().element_type(), + AsInt64Slice(index->shape().dimensions()))); + TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, + MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); + + // Check if the index is OOB w.r.t. the operand dimensions and window sizes. + std::vector max_valid_index(operand_dims.size()); + for (int i = 0; i < operand_dims.size(); ++i) { + max_valid_index[i] = operand_dims[i] - window_sizes[i]; + } + TF_ASSIGN_OR_RETURN( + HloInstruction * max_valid_index_constant, + MakeR1ConstantHlo(computation, index->shape().element_type(), + max_valid_index)); + TF_ASSIGN_OR_RETURN( + HloInstruction * oob_index_check, + MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index)); + + // Combine the results of the two checks above. + TF_ASSIGN_OR_RETURN( + HloInstruction * valid_index, + MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check)); + + // Reduce the index validity check vector into a scalar predicate. + auto reduction_init = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + TF_ASSIGN_OR_RETURN( + HloInstruction * valid_index_reduced, + MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module)); + + // Return a broadcasted value of the scalar predicate to the same size as the + // window. + return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes); +} + +// Body of the while loop that performs the scatter operation using other HLOs. +static StatusOr> ScatterLoopBody( + HloInstruction* scatter, HloInstruction* induction_var, + const std::vector& loop_state) { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK_EQ(loop_state.size(), 3); + HloInstruction* operand = loop_state[0]; + HloInstruction* scatter_indices = loop_state[1]; + HloInstruction* updates = loop_state[2]; + + bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; + CHECK_EQ(has_scalar_indices, + dim_numbers.index_vector_dim() == + scatter->operand(1)->shape().dimensions_size()); + + // Build a vector form of the induction variable of the while loop. + TF_ASSIGN_OR_RETURN( + HloInstruction * induction_var_as_vector, + MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); + + // Pick the index to scatter from scatter_indices based on the induction_var + // and transform that to an index into the `operand` space. + HloInstruction* index_vector; + if (has_scalar_indices) { + TF_ASSIGN_OR_RETURN( + index_vector, + MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1})); + } else { + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_scatter_indices, + PadVectorWithZeros(induction_var_as_vector, + /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + int index_vector_size = scatter_indices->shape().dimensions(1); + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_2d, + MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices, + {1, index_vector_size})); + TF_ASSIGN_OR_RETURN(index_vector, + ElideDegenerateDims(index_vector_2d, {0})); + } + TF_ASSIGN_OR_RETURN( + HloInstruction * scatter_slice_start, + ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, + operand->shape().dimensions_size())); + + // Extract the slice to be used to update from `updates` tensor for the + // induction_var corresponding to this iteration of the while loop. + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_updates, + PadVectorWithZeros( + induction_var_as_vector, /*zeros_to_prepend=*/0, + /*zeros_to_append=*/updates->shape().dimensions_size() - 1)); + std::vector update_slice_bounds(updates->shape().dimensions().begin(), + updates->shape().dimensions().end()); + update_slice_bounds[0] = 1; + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice, + MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds)); + TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter, + ElideDegenerateDims(update_slice, {0})); + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice_with_dims_inserted, + InsertDegenerateDims(update_slice_for_scatter, + AsInt64Slice(dim_numbers.inserted_window_dims()))); + + // Note that the following transformation assumes that both DynamicSlice and + // DynamicUpdateSlice follow the same semantics for OOB indices. For example, + // if there are negative indices and DynamicSlice uses "clamping" semantics, + // then the extracted data will be "shifted". Since DynamicUpdateSlice also + // follows the same "clamping" semantics, writing the update will also be + // "shifted" by exactly the same amount. So, this transformation is correct as + // long as the semantics of handling OOB indices remain the same in + // DynamicSlice and DynamicUpdateSlice. + + // Extract the slice to update from `operand` tensor. + const Shape& update_slice_shape = update_slice_with_dims_inserted->shape(); + TF_ASSIGN_OR_RETURN( + HloInstruction * operand_slice_to_update, + MakeDynamicSliceHlo(operand, scatter_slice_start, + AsInt64Slice(update_slice_shape.dimensions()))); + + // Compute the new value for the slice to be updated in `operand` tensor by + // combining the existing value and the update value using the update + // computation. + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_operand_slice, + MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted}, + scatter->to_apply())); + + TF_ASSIGN_OR_RETURN( + HloInstruction * is_index_valid, + CheckIndexValidity( + operand->parent(), scatter_slice_start, + AsInt64Slice(operand->shape().dimensions()), + AsInt64Slice(update_slice_with_dims_inserted->shape().dimensions()), + scatter->GetModule())); + + // Select the updated operand only if the index is valid. If not, select the + // original value. + TF_ASSIGN_OR_RETURN(HloInstruction * update_to_apply, + MakeSelectHlo(is_index_valid, updated_operand_slice, + operand_slice_to_update)); + + // Write the updated value of the slice into `operand` tensor. + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_operand, + MakeDynamicUpdateSliceHlo(operand, update_to_apply, scatter_slice_start)); + + return StatusOr>{ + {updated_operand, scatter_indices, updates}}; +} + +// High Level Algorithm. +// +// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where +// each row is an index into the operand. +// 2. Canonicalize the updates tensor such that is has rank `num_window_dims+1` +// and the scatter dim is the most-major dimension. +// 3. Iterate over the set of indices in the canonicalized scatter_indices +// tensor using a while loop, updating the operand for each such index. Each +// iteration of this while loop performs the following: +// a. Pick the index from scatter_indices for this iteration. +// b. Transfrom this index into an index into the operand space. +// c. Extract the slice to be used to update from the updates tensor. +// d. Extract the slice to update from the operand tensor. +// e. Compute the new value for the slice to update by combining the slices +// from c. and d. using the update_computation of scatter. +// f. Write the updated value of the slice into the operand tensor. + +StatusOr ScatterExpander::ExpandScatter( + HloInstruction* scatter) { + HloInstruction* operand = scatter->mutable_operand(0); + HloInstruction* scatter_indices = scatter->mutable_operand(1); + HloInstruction* updates = scatter->mutable_operand(2); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + + // If the updates tensor is empty, there is no need to update the operand. We + // can return the operand as is. + if (ShapeUtil::IsZeroElementArray(updates->shape())) { + return operand; + } + + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + const Shape& scatter_indices_shape = scatter_indices->shape(); + int64 scatter_loop_trip_count = 1; + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); + } + } + if (!IsInt32(scatter_loop_trip_count)) { + return Unimplemented( + "Scatter operations with more than 2147483647 scatter indices are not " + "supported. This error occurred for %s.", + scatter->ToString()); + } + + // Canonicalize the scatter_indices, after which the size of its most-major + // dimension must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices, + CanonicalizeScatterIndices( + scatter_indices, dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + canonical_scatter_indices->shape().dimensions(0)); + + // Canonicalize the updates, after which the size of its most-major dimension + // must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_updates, + PermuteScatterAndWindowDims( + updates, AsInt64Slice(dim_numbers.update_window_dims()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * adjusted_canonical_updates, + AdjustScatterDims(scatter_indices->shape(), canonical_updates, + dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + adjusted_canonical_updates->shape().dimensions(0)); + + // The while loop that implements the scatter operation. + StatusOr> scatter_loop_result_status = + WhileUtil::MakeCountedLoop( + scatter->parent(), scatter_loop_trip_count, + {operand, canonical_scatter_indices, adjusted_canonical_updates}, + [&](HloInstruction* induction_var, + const std::vector& loop_state) { + return ScatterLoopBody(scatter, induction_var, loop_state); + }); + TF_ASSIGN_OR_RETURN(std::vector scatter_loop_result, + scatter_loop_result_status); + return scatter_loop_result.front(); +} + +StatusOr ScatterExpander::Run(HloModule* module) { + std::vector scatter_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kScatter) { + scatter_instrs.push_back(instr); + } + } + } + + for (auto instr : scatter_instrs) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr)); + TF_RETURN_IF_ERROR( + instr->parent()->ReplaceInstruction(instr, expanded_root)); + } + + return !scatter_instrs.empty(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/service/scatter_expander.h similarity index 53% rename from tensorflow/compiler/xla/ptr_util.h rename to tensorflow/compiler/xla/service/scatter_expander.h index bfcdfc62f9541ab09b94a48d5121e16bad4d43cd..559a85dccfef27816e7dbf746fd71c44bbf46f60 100644 --- a/tensorflow/compiler/xla/ptr_util.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,23 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ -// As this was moved to tensorflow/core/util, provide indirections here to -// maintain current functionality of the library. +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include +namespace xla { -#include -#include -#include +class ScatterExpander : public HloModulePass { + public: + absl::string_view name() const override { return "scatter_expander"; } + StatusOr Run(HloModule* module) override; -#include "tensorflow/core/util/ptr_util.h" + private: + StatusOr ExpandScatter(HloInstruction* scatter); +}; -namespace xla { -using tensorflow::MakeUnique; -using tensorflow::WrapUnique; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 212db0643c3a49c45dc317547c8f1bfc82b7e8b0..b27a92f2a0761a2bccd97eb2c0467ead27565c37 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -20,10 +20,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -46,32 +48,29 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" - -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrCat; +#include "tensorflow/core/util/ptr_util.h" namespace xla { - namespace { +using absl::StrCat; +using absl::StrFormat; + // Records the arguments used to invoke a computation in an HloSnapshot proto. -Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - se::Stream* stream, TransferManager* transfer_manager, - HloSnapshot* module) { +Status RecordArguments(const absl::Span arguments, + se::Stream* stream, TransferManager* transfer_manager, + HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, *argument)); - *module->add_arguments() = literal->ToProto(); + *module->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -81,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, result)); - *module->mutable_result() = literal->ToProto(); + *module->mutable_result() = literal.ToProto(); return Status::OK(); } @@ -147,19 +146,19 @@ Service::Service(const ServiceOptions& options, CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) << "Requested more replicas than there are devices."; } - LOG(INFO) << Printf( + LOG(INFO) << StrFormat( "XLA service %p executing computations on platform %s. Devices:", this, - execute_backend_->platform()->Name().c_str()); + execute_backend_->platform()->Name()); for (int i = 0; i < execute_backend_->device_count(); ++i) { if (execute_backend_->device_ordinal_supported(i)) { se::StreamExecutor* executor = execute_backend_->stream_executor(i).ValueOrDie(); const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i, - description.name().c_str(), - description.platform_version().c_str()); + LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, + description.name(), + description.platform_version()); } else { - LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i); + LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i); } } } else { @@ -199,16 +198,16 @@ Status Service::ValidateResultShape(const Shape& client_shape, return InvalidArgument( "Shape used to set computation result layout %s is not compatible " "with result shape %s", - ShapeUtil::HumanStringWithLayout(client_shape).c_str(), - ShapeUtil::HumanString(result_shape).c_str()); + ShapeUtil::HumanStringWithLayout(client_shape), + ShapeUtil::HumanString(result_shape)); } return Status::OK(); } StatusOr>> Service::ResolveAndValidateArguments( - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice stream_executors) { + absl::Span arguments, + absl::Span stream_executors) { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); @@ -230,9 +229,9 @@ Service::ResolveAndValidateArguments( return InvalidArgument( "argument %lu is on device %s:%d but computation will be executed " "on device %s", - i, shaped_buffer->platform()->Name().c_str(), + i, shaped_buffer->platform()->Name(), shaped_buffer->device_ordinal(), - execute_backend_->device_name(replica_device_ordinal).c_str()); + execute_backend_->device_name(replica_device_ordinal)); } replicated_arguments[replica].push_back(shaped_buffer); } @@ -242,13 +241,13 @@ Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice argument_shapes, + absl::Span argument_shapes, const ExecutionOptions* execution_options) { - auto config = MakeUnique(program_shape); + auto config = absl::make_unique(program_shape); ComputationLayout* computation_layout = config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { - return InvalidArgument("computation takes %d parameters, but %zu given", + return InvalidArgument("computation takes %d parameters, but %u given", program_shape.parameters_size(), argument_shapes.size()); } @@ -260,8 +259,8 @@ StatusOr> Service::CreateModuleConfig( return InvalidArgument( "Argument does not match shape of computation parameter %d: want " "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); + i, ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(*argument_shapes[i])); } TF_RETURN_IF_ERROR( computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -299,7 +298,7 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { @@ -313,7 +312,7 @@ StatusOr>> Service::BuildExecutables( std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); + VLOG(1) << StrFormat("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. std::vector> hlo_snapshots; @@ -325,12 +324,11 @@ StatusOr>> Service::BuildExecutables( if (directory_path.empty() && execution_directory_path.empty()) { continue; } - auto hlo_snapshot = MakeUnique(); + auto hlo_snapshot = absl::make_unique(); *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; if (!directory_path.empty()) { - string filename = - Printf("computation_%lld__%s", module_protos[i]->id(), - module_protos[i]->entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_protos[i]->id(), + module_protos[i]->entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } @@ -368,12 +366,10 @@ StatusOr>> Service::BuildExecutables( StatusOr> Service::ExecuteParallelAndRegisterResult( - tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice>> - arguments, - Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags, - ExecutionProfile* profile) { + absl::Span executables, + absl::Span>> arguments, + Backend* backend, absl::Span device_handles, + absl::Span result_tags, ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. std::vector streams; @@ -408,7 +404,8 @@ Service::ExecuteParallelAndRegisterResult( streams.push_back(std::move(stream)); if (replica == 0 && profile != nullptr) { - timers.emplace_back(new se::Timer(streams.back()->parent())); + timers.push_back( + absl::make_unique(streams.back()->parent())); streams.back() ->InitTimer(timers.back().get()) .ThenStartTimer(timers.back().get()); @@ -440,7 +437,7 @@ Service::ExecuteParallelAndRegisterResult( streams.back()->ThenStopTimer(timers.back().get()); } - result_buffers.emplace_back(std::move(result)); + result_buffers.push_back(std::move(result)); } TF_ASSIGN_OR_RETURN(GlobalDataHandle handle, allocation_tracker_.RegisterReplicatedBuffers( @@ -452,8 +449,8 @@ Service::ExecuteParallelAndRegisterResult( for (int64 i = 0; i < streams.size(); ++i) { Status block_status = streams[i]->BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("failed to complete execution for stream %lld: %s", - i, block_status.error_message().c_str()); + return InternalError("failed to complete execution for stream %d: %s", i, + block_status.error_message()); } } @@ -511,8 +508,7 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice> - arguments, + const absl::Span> arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector streams; @@ -555,10 +551,9 @@ StatusOr Service::ExecuteAndRegisterResult( // TODO(b/69985541): Support profiling also on this path. - std::vector> - replicated_arguments; + std::vector> replicated_arguments; for (const auto& arg : arguments) { - replicated_arguments.emplace_back(arg); + replicated_arguments.push_back(arg); } TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( @@ -578,7 +573,7 @@ StatusOr> Service::GetExecutors( if (requests_size > 1 && execution_options.device_handles_size() > 1) { return InvalidArgument( "Parallel requests with multiple device handles is not supported. " - "Found %lld parallel requests, with request %lld containing %d device " + "Found %d parallel requests, with request %d containing %d device " "handles.", requests_size, request_index, execution_options.device_handles_size()); } @@ -595,7 +590,7 @@ StatusOr> Service::GetExecutors( StatusOr>> Service::GetArguments( const ExecutionOptions& execution_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. // In the case of partitioned computations, assume all arguments go on the @@ -743,8 +738,8 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%lld) exceeds the number of available devices " - "on the target (%lld)", + "Requested device count (%d) exceeds the number of available devices " + "on the target (%d)", arg->device_count(), available_device_count); } @@ -794,12 +789,12 @@ StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf( + VLOG(1) << StrFormat( "BuildExecutable on service %p with serialized module proto: %s", this, - module_proto.name().c_str()); + module_proto.name()); // Dump computation proto state if flag is set. - auto hlo_snapshot = MakeUnique(); + auto hlo_snapshot = absl::make_unique(); const string& directory_path = module_config->debug_options().xla_dump_computations_to(); const string& execution_directory_path = @@ -807,8 +802,8 @@ StatusOr> Service::BuildExecutable( if (!directory_path.empty() || !execution_directory_path.empty()) { *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s", module_proto.id(), - module_proto.entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_proto.id(), + module_proto.entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } @@ -817,7 +812,7 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(module_proto, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, @@ -933,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, shaped_buffer->device_ordinal())); TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, + Literal result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); - if (LayoutUtil::LayoutsInShapesEqual(*return_shape, - result_literal->shape())) { - *result->mutable_literal() = result_literal->ToProto(); + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) { + *result->mutable_literal() = result_literal.ToProto(); } else { *result->mutable_literal() = - result_literal->Relayout(*return_shape)->ToProto(); + result_literal.Relayout(*return_shape).ToProto(); } return Status::OK(); } @@ -953,7 +947,7 @@ namespace { // shape and DeviceMemoryBase values of the clone are identical to the original. std::unique_ptr CloneShapedBufferOnDevice( const ShapedBuffer& shaped_buffer, int device_ordinal) { - auto clone = MakeUnique( + auto clone = absl::make_unique( shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), shaped_buffer.platform(), device_ordinal); clone->buffers() = shaped_buffer.buffers(); @@ -964,9 +958,9 @@ std::unique_ptr CloneShapedBufferOnDevice( Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); std::vector replicas; if (arg->has_device_handle()) { @@ -988,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - stream.get(), *literal, shaped_buffer)); + stream.get(), literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), @@ -1008,8 +1002,7 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, "%s", StrCat("The replica_id=", arg->replica_id(), " on TransferToInfeedRequest not in range [0, replica_count=", - replica_count, ").") - .c_str()); + replica_count, ").")); } se::StreamExecutor* executor; @@ -1024,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor = replicas[arg->replica_id()]; } - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, *literal); + return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor, + literal); } Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, @@ -1035,8 +1028,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( - "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, " - "%lld)", + "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)", arg->replica_id(), replica_count); } @@ -1052,10 +1044,11 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, executor = replicas[arg->replica_id()]; } - Literal literal; + auto literal = Literal::CreateFromShape(arg->shape_with_layout()); + TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), &literal)); + executor, arg->shape_with_layout(), literal)); *result->mutable_literal() = literal.ToProto(); return Status::OK(); } @@ -1091,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, HloModule::CreateFromProto(arg->computation(), config)); HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, - evaluator.Evaluate>( - *module, /*arg_literals=*/{})); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( + *module, /*arg_literals=*/{})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); + result_literal = result_literal.Relayout(arg->output_layout()); } - *result->mutable_literal() = result_literal->ToProto(); + *result->mutable_literal() = result_literal.ToProto(); return Status::OK(); } @@ -1168,7 +1160,7 @@ StatusOr> Service::Replicas( return replicas; } -Status Service::MaybeDumpHloModule(const HloModule& module) const { +Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const { const string xla_dump_unoptimized_hlo_proto_to = module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); if (xla_dump_unoptimized_hlo_proto_to.empty()) { @@ -1176,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const { } HloProto proto = MakeHloProto(module); return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_unoptimized_hlo_proto_to, module.name()); + proto, xla_dump_unoptimized_hlo_proto_to, + StrCat(module.name(), ".unoptimized")); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 47d196fb2aaee897ce1fd3745129af10bf5b2d2d..1f62fad4c8079eba7013b3f647fe19bbc031fc77 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/allocation_tracker.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -176,7 +176,7 @@ class Service : public ServiceInterface { // class. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions& execution_options); // Picks a parallel response and fills the result. @@ -191,7 +191,7 @@ class Service : public ServiceInterface { // Prepare the arguments for executing parallel. StatusOr>> GetArguments( const ExecutionOptions& execution_options, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); protected: friend class LocalExecutable; @@ -207,14 +207,14 @@ class Service : public ServiceInterface { // the corresponding replica. StatusOr>> ResolveAndValidateArguments( - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice stream_executors); + absl::Span arguments, + absl::Span stream_executors); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice argument_shapes, + absl::Span argument_shapes, const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. @@ -242,21 +242,17 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice> - arguments, + const absl::Span> arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( - tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice>> - arguments, - Backend* backend, - tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags, - ExecutionProfile* profile); + absl::Span executables, + absl::Span>> arguments, + Backend* backend, absl::Span device_handles, + absl::Span result_tags, ExecutionProfile* profile); // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which @@ -275,7 +271,9 @@ class Service : public ServiceInterface { StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; - Status MaybeDumpHloModule(const HloModule& module) const; + // Dumps the (unoptimized) module given if the corresponding DebugOptions + // field has been set. + Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const; // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c888bbf144954c3b48afecf80a8884e847cc9d18..7194b2cafd348c144a2ee83027cf78642bfaf75f 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -21,6 +21,11 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -28,44 +33,37 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" -using tensorflow::str_util::Join; -using tensorflow::strings::Printf; - namespace xla { - namespace { +using absl::StrFormat; +using absl::StrJoin; + // Returns true if no element is present in slice more than once. -bool AllUnique(tensorflow::gtl::ArraySlice slice) { +bool AllUnique(absl::Span slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { +Status ExpectArray(const Shape& shape, absl::string_view op_type) { if (!ShapeUtil::IsArray(shape)) { return InvalidArgument("Expected array argument for %s, but got %s.", - std::string(op_type).c_str(), - ShapeUtil::HumanString(shape).c_str()); + string(op_type), ShapeUtil::HumanString(shape)); } return Status::OK(); } -Status VerifyReducerShape( - const ProgramShape& reducer_shape, - tensorflow::gtl::ArraySlice init_value_shapes, - tensorflow::gtl::ArraySlice input_element_types, - int64 inputs) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + absl::Span init_value_shapes, + absl::Span input_element_types, + int64 inputs) { if (reducer_shape.parameters_size() != inputs * 2) { return InvalidArgument( - "Reduction function must take %lld parameters, but " + "Reduction function must take %d parameters, but " "takes %d parameter(s).", inputs * 2, reducer_shape.parameters_size()); } @@ -75,7 +73,7 @@ Status VerifyReducerShape( if (ShapeUtil::IsArray(accumulator_shape)) { if (inputs != 1) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but " + "Reduction function must produce a tuple with %d elements, but " "produces a scalar", inputs); } @@ -83,8 +81,8 @@ Status VerifyReducerShape( } else if (ShapeUtil::IsTuple(accumulator_shape)) { if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but has " - "%lld elements", + "Reduction function must produce a tuple with %d elements, but has " + "%d elements", inputs, ShapeUtil::TupleElementCount(accumulator_shape)); } for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { @@ -94,7 +92,7 @@ Status VerifyReducerShape( return InvalidArgument( "Reduction function must produce a scalar or tuple of scalars, but has " "shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } for (const Shape* element_shape : accumulator_subshapes) { @@ -102,7 +100,7 @@ Status VerifyReducerShape( return InvalidArgument( "Reduction function must return a scalar or tuple of scalars but " "returns shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } } @@ -113,19 +111,19 @@ Status VerifyReducerShape( if (!ShapeUtil::Compatible(*accumulator_subshapes[i], reducer_shape.parameters(i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "result shape: %s vs %s", - i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + i, ShapeUtil::HumanString(reducer_shape.parameters(i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } // Check that init_value's shapes are suitable for reducer_shape. if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], *init_value_shapes[i])) { return InvalidArgument( - "Reduction function's accumulator shape at index %lld differs from " + "Reduction function's accumulator shape at index %d differs from " "the init_value shape: %s vs %s", - i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), - ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + i, ShapeUtil::HumanString(*accumulator_subshapes[i]), + ShapeUtil::HumanString(*init_value_shapes[i])); } // Check that the inputs can be passed in as the non-accumulator arguments. const Shape input_element_shape = @@ -133,11 +131,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( input_element_shape, reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "input type element type: %s vs %s", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(input_element_shape)); } // Check that the accumulator and inputs to the reducer function match. // If the accumulator is scalar, it must have the same type as the inputs @@ -147,11 +145,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape must " + "Reduction function's %d-th parameter shape must " "match the result shape, but got %s vs %s.", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } } @@ -164,7 +162,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, bool allow_negative_padding) { if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { return InvalidArgument( - "Window has dimension %d but base shape has dimension %lld.", + "Window has dimension %d but base shape has dimension %d.", window.dimensions_size(), ShapeUtil::Rank(base_shape)); } @@ -173,29 +171,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const auto& dim = window.dimensions(i); if (dim.size() <= 0) { return InvalidArgument("Window %s has a non-positive dimension.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.stride() <= 0) { return InvalidArgument("Window %s has a non-positive stride.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_low() < 0) { return InvalidArgument("Window %s has a negative low padding.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_high() < 0) { return InvalidArgument("Window %s has a negative high padding.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.base_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive base area dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.window_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive window dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } const int64 dilated_base = window_util::DilatedBound( @@ -233,11 +231,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, switch (opcode) { case HloOpcode::kFloor: case HloOpcode::kCeil: + case HloOpcode::kRoundNearestAfz: if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating for floor/ceil " - "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "Expected element type in shape to be floating for %s operation; " + "got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kCos: @@ -250,9 +249,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( - "Expected element type in shape to be floating or complex for " - "sin/cos/exp/log/tanh operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "Expected element type in shape to be floating or complex for %s " + "operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kReal: @@ -264,19 +263,47 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } else { return InvalidArgument( "Expected element type in shape to be floating or complex for " - "real/imag operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } case HloOpcode::kAbs: if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( shape, primitive_util::ComplexComponentType(shape.element_type())); + } else if (ShapeUtil::ElementIsSigned(shape)) { + return shape; + } else { + return InvalidArgument( + "Expected element type in shape to be floating or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } - return shape; case HloOpcode::kClz: + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Expected an integral element type in argument to Clz " + "operation; got %s.", + PrimitiveType_Name(shape.element_type())); + } + return shape; case HloOpcode::kNegate: - case HloOpcode::kRoundNearestAfz: + if (!ShapeUtil::ElementIsIntegral(shape) && + !ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be integral, floating or " + "complex for %s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); + } + return shape; case HloOpcode::kSign: + if (!ShapeUtil::ElementIsSigned(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be signed or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); + } return shape; case HloOpcode::kNot: @@ -285,7 +312,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return shape; @@ -295,25 +322,24 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, "Expected element type in shape to be floating " "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - HloOpcodeString(opcode).c_str()); + HloOpcodeString(opcode)); } } /* static */ StatusOr ShapeInference::InferConcatOpShape( - tensorflow::gtl::ArraySlice arg_shapes, - const int64 dimension) { + absl::Span arg_shapes, const int64 dimension) { if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); } if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { - return InvalidArgument("Concatenate dimension out of bounds: %lld.", + return InvalidArgument("Concatenate dimension out of bounds: %d.", dimension); } const Shape* arg_shape = nullptr; @@ -327,17 +353,16 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( - "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " + "Cannot concatenate arrays with different ranks: %d (%s) vs %d " "(%s).", - ShapeUtil::Rank(*arg_shape), - ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), - ShapeUtil::HumanString(*shape).c_str()); + ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape), + ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( "Cannot concatenate arrays with different element types: %s vs %s.", - PrimitiveType_Name(arg_shape->element_type()).c_str(), - PrimitiveType_Name(shape->element_type()).c_str()); + PrimitiveType_Name(arg_shape->element_type()), + PrimitiveType_Name(shape->element_type())); } for (int64 dimension_number = 0; dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { @@ -350,9 +375,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Cannot concatenate arrays that differ in dimensions other than " "the one being concatenated (the other array dimensions must be " - "the same): %s vs %s in dimension %lld.", - ShapeUtil::HumanString(*arg_shape).c_str(), - ShapeUtil::HumanString(*shape).c_str(), dimension); + "the same): %s vs %s in dimension %d.", + ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape), + dimension); } } element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); @@ -367,7 +392,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } /* static */ StatusOr ShapeInference::InferAfterAllShape( - tensorflow::gtl::ArraySlice arg_shapes) { + absl::Span arg_shapes) { for (const Shape* arg_shape : arg_shapes) { if (arg_shape->element_type() != TOKEN) { return InvalidArgument( @@ -384,8 +409,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, !primitive_util::IsComplexType(new_element_type)) { return Unimplemented( "Conversion from complex to real type %s => %s is not implemented.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -394,8 +419,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. return InvalidArgument( "Convert does not allow non-arrays, so cannot convert from %s to %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -407,8 +432,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (primitive_util::IsComplexType(old_element_type) != primitive_util::IsComplexType(new_element_type)) { return InvalidArgument("Conversion from complex to real type %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -417,15 +442,15 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. return InvalidArgument( "Cannot convert from or to tuple type; requested conversion: %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (primitive_util::BitWidth(old_element_type) != primitive_util::BitWidth(new_element_type)) { return InvalidArgument( "Cannot bitcast types with different bit-widths: %s => %s.", - PrimitiveType_Name(old_element_type).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + PrimitiveType_Name(old_element_type), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -438,7 +463,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be floating point for " "ReducePrecision operation; got %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (exponent_bits < 1) { // One exponent bit is necessary to distinguish 0 from infinity. Having @@ -470,21 +495,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "The rank of the operand and the padding configuration do not match: " "%s vs %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - padding_config.ShortDebugString().c_str()); + ShapeUtil::HumanString(operand_shape), + padding_config.ShortDebugString()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, padding_value_shape)) { return InvalidArgument( "The element types of the operands to Pad do not match."); } + if (absl::c_any_of(padding_config.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& p) { + return p.interior_padding() < 0; + })) { + return InvalidArgument("Interior padding cannot be negative: %s", + padding_config.ShortDebugString()); + } + std::vector dimensions(ShapeUtil::Rank(operand_shape)); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { - dimensions[i] = operand_shape.dimensions(i) + - padding_config.dimensions(i).edge_padding_low() + - padding_config.dimensions(i).edge_padding_high() + + const auto& p = padding_config.dimensions(i); + dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + + p.edge_padding_high() + std::max(operand_shape.dimensions(i) - 1, 0LL) * - padding_config.dimensions(i).interior_padding(); + p.interior_padding(); } return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), @@ -515,22 +548,22 @@ Status ValidateDotDimensionNumbers( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers) { // Check that dimension numbers are in range. - auto dims_in_range = - [](const int64 rank, tensorflow::gtl::ArraySlice contracting_dims, - tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto dims_in_range = [](const int64 rank, + absl::Span contracting_dims, + absl::Span batch_dims) -> bool { auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; return std::all_of(contracting_dims.begin(), contracting_dims.end(), in_range) && std::all_of(batch_dims.begin(), batch_dims.end(), in_range); }; - tensorflow::gtl::ArraySlice lhs_contracting_dimensions = + absl::Span lhs_contracting_dimensions = AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); - tensorflow::gtl::ArraySlice rhs_contracting_dimensions = + absl::Span rhs_contracting_dimensions = AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); - tensorflow::gtl::ArraySlice lhs_batch_dimensions = + absl::Span lhs_batch_dimensions = AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); - tensorflow::gtl::ArraySlice rhs_batch_dimensions = + absl::Span rhs_batch_dimensions = AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, @@ -538,12 +571,12 @@ Status ValidateDotDimensionNumbers( !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is out of range in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that dimension numbers are unique. - auto dims_unique = [](tensorflow::gtl::ArraySlice contracting_dims, - tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto dims_unique = [](absl::Span contracting_dims, + absl::Span batch_dims) -> bool { tensorflow::gtl::FlatSet dim_set; auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; @@ -556,7 +589,7 @@ Status ValidateDotDimensionNumbers( if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is not unique in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. @@ -601,14 +634,13 @@ Status ValidateDotDimensionNumbers( TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); auto fail = [lhs, rhs](const string& addendum) -> Status { - string message = tensorflow::strings::Printf( - "Cannot infer shape for dot operation: %s %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + string message = + StrFormat("Cannot infer shape for dot operation: %s %s.", + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); if (!addendum.empty()) { message += " " + addendum; } - return InvalidArgument("%s", message.c_str()); + return InvalidArgument("%s", message); }; // Check if both element types are the same. @@ -704,9 +736,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - HloOpcodeString(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), @@ -715,20 +746,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring // the user to provide an explicit broadcast dimension in this case. // See b/25177275 for more details. return InvalidArgument("Automatic shape inference not supported: %s and %s", - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { return InvalidArgument( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " - " lower-rank operand's rank is %lld, size of broadcast_dimensions is " - "%zu.", + " lower-rank operand's rank is %d, size of broadcast_dimensions is " + "%u.", ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); } @@ -778,12 +809,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 dimension_to_match = broadcast_dimensions.at(i); if (dimension_to_match < 0) { return InvalidArgument( - "Broadcast dimension number (%lld) cannot be negative.", + "Broadcast dimension number (%d) cannot be negative.", dimension_to_match); } if (dimension_to_match >= larger_shape.dimensions_size()) { return InvalidArgument( - "Broadcast dimension number (%lld) too large; higher-rank " + "Broadcast dimension number (%d) too large; higher-rank " "operand has rank %d.", dimension_to_match, larger_shape.dimensions_size()); } @@ -795,16 +826,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (small_dimension_size != large_dimension_size && small_dimension_size != 1 && large_dimension_size != 1) { return InvalidArgument( - "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i, + "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i, small_dimension_size, large_dimension_size, - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } // Make sure the broadcast dimensions are listed in a strictly increasing // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { return InvalidArgument( - "Broadcast dimensions order is wrong: %lld comes after %lld.", + "Broadcast dimensions order is wrong: %d comes after %d.", dimension_to_match, broadcast_dimensions.at(i - 1)); } @@ -816,15 +847,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { @@ -873,21 +904,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - VLOG(2) << tensorflow::strings::Printf( + absl::Span broadcast_dimensions) { + VLOG(2) << StrFormat( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), - Join(broadcast_dimensions, ", ").c_str()); + HloOpcodeString(opcode), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", ")); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR( - ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ", - HloOpcodeString(opcode)))); - TF_RETURN_IF_ERROR( - ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ", - HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode)))); switch (opcode) { case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -909,7 +937,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected element type in shape to be floating for complex compose " "operation; got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(opcode, lhs, rhs, @@ -928,7 +956,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected pred or integral type in argument to and/or operation; " "got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); @@ -946,8 +974,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), - rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode), lhs.ShortDebugString(), + rhs.ShortDebugString()); } } @@ -970,14 +998,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kTupleSelect: return InferTupleSelectShape(lhs, rhs, ehs); default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode)); } } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + HloOpcode opcode, absl::Span operands) { std::vector operand_shapes; operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { @@ -987,8 +1013,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operand_shapes) { + HloOpcode opcode, absl::Span operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } @@ -1010,8 +1035,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Sort keys and values dimensions must match. " "Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(*operand_shapes[0]).c_str(), - ShapeUtil::HumanString(*operand_shapes[1]).c_str()); + ShapeUtil::HumanString(*operand_shapes[0]), + ShapeUtil::HumanString(*operand_shapes[1])); } return ShapeUtil::MakeTupleShape( {*operand_shapes[0], *operand_shapes[1]}); @@ -1019,15 +1044,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Unexpected number of operands for sort"); } default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode)); } } /* static */ StatusOr ShapeInference::InferMapShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span arg_shapes, const ProgramShape& to_apply, + absl::Span dimensions) { if (arg_shapes.empty()) { return InvalidArgument("Map expects at least one argument."); } @@ -1058,7 +1081,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Map operation requires all operands to have the same shape; got: " "%s.", - Join(pieces, ", ").c_str()); + StrJoin(pieces, ", ")); } // Check that dimensions.size == arg_shape.dimensions_size() (we currently @@ -1066,7 +1089,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions.size() != arg_shape->dimensions_size()) { return InvalidArgument( "Map applied to a subset of dimensions currently not supported: " - "arg_dimension_size: %d, requested_map_dimensions_size: %zu.", + "arg_dimension_size: %d, requested_map_dimensions_size: %u.", arg_shape->dimensions_size(), dimensions.size()); } @@ -1075,7 +1098,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions[i] != i) { return InvalidArgument( "Map requires monotonically increasing dimension numbers; got: %s.", - Join(dimensions, ", ").c_str()); + StrJoin(dimensions, ", ")); } } @@ -1083,7 +1106,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( "Map applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu.", + "arity: %d, arguments: %u.", to_apply.parameters_size(), arg_shapes.size()); } @@ -1092,7 +1115,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsScalar(output_shape)) { return InvalidArgument( "Mapped computation's result has to be a scalar; got: %s.", - ShapeUtil::HumanString(output_shape).c_str()); + ShapeUtil::HumanString(output_shape)); } for (int i = 0; i < to_apply.parameters_size(); ++i) { @@ -1102,7 +1125,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter has to be a scalar; " "got parameter %d shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, @@ -1110,8 +1133,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter type has to match argument element " "type; got parameter %d shape: %s, argument shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str(), - ShapeUtil::HumanString(*arg_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape), + ShapeUtil::HumanString(*arg_shape)); } } @@ -1140,35 +1163,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-training to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-training to be at least 1; got %lld.", + "batch-norm-training to be at least 1; got %d.", ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } @@ -1176,7 +1199,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-training must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1185,8 +1208,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1195,8 +1218,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1206,16 +1229,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1250,35 +1273,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-inference to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-inference to be at least 1; got %lld.", + "batch-norm-inference to be at least 1; got %d.", ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } @@ -1286,7 +1309,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-inference must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1296,8 +1319,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1307,8 +1330,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1318,8 +1341,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of mean is %s " "and the shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape, @@ -1329,8 +1352,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of variance is %s " "and the shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(variance_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(variance_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1340,32 +1363,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of mean is %lld " - "and the feature count is %lld.", + "but the size of mean is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(variance_shape, 0), feature_count); } @@ -1395,36 +1418,36 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" - " output_grad_shape; got rank(oprand_shape) %lld, and" - " rank(output_grad_shape) %lld.", + " output_grad_shape; got rank(oprand_shape) %d, and" + " rank(output_grad_shape) %d.", ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); } if (ShapeUtil::Rank(mean_shape) != 1) { return InvalidArgument( "Mean input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(mean_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } if (ShapeUtil::Rank(var_shape) != 1) { return InvalidArgument( "Var input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(var_shape)); } @@ -1432,14 +1455,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::ElementIsFloating(output_grad_shape)) { return InvalidArgument( "The output_grad to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape, @@ -1448,8 +1471,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " "and the element type of operand is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1458,8 +1481,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " "and the element type of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1468,8 +1491,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape, @@ -1478,8 +1501,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1490,24 +1513,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(var_shape, 0), feature_count); } @@ -1517,8 +1540,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::GetDimension(output_grad_shape, i)) { return InvalidArgument( "The bounds of operand shape should be the same as output_grad's," - "but the bound of operand_shape at dimension %lld is %lld " - "and the bound of output_grad_shape is %lld.", + "but the bound of operand_shape at dimension %d is %d " + "and the bound of output_grad_shape is %d.", i, ShapeUtil::GetDimension(operand_shape, i), ShapeUtil::GetDimension(output_grad_shape, i)); } @@ -1529,23 +1552,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums) { + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dnums) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); } if (dnums.input_spatial_dimensions_size() != dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( "Both arguments to convolution must have same number of dimensions.\n" "Window: %s", - window.DebugString().c_str()); + window.DebugString()); } const int num_spatial_dims = dnums.input_spatial_dimensions_size(); @@ -1553,19 +1575,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" "Window: %s\nDimension numbers: %s.", - window.DebugString().c_str(), dnums.DebugString().c_str()); + window.DebugString(), dnums.DebugString()); } const int num_dims = num_spatial_dims + 2; if (ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } if (ShapeUtil::Rank(rhs) != num_dims) { return InvalidArgument( "The RHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -1602,26 +1624,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (input_dnums != expected_dnums) { return InvalidArgument( "Input dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (window_dnums != expected_dnums) { return InvalidArgument( "Window dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (output_dnums != expected_dnums) { return InvalidArgument( "Output dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } std::vector input_spatial_dims(num_spatial_dims); @@ -1640,14 +1662,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); - if (input_features != kernel_input_features) { + if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( - "Expected LHS feature dimension (value %lld) to match RHS " - "input feature dimension (value %lld); got (%s, %s)\n" + "Expected LHS feature dimension (value %d) to match RHS " + "input feature dimension * feature_group_count (value %d * %d = %d); " + "got (%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features, - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); + input_features, kernel_input_features, feature_group_count, + kernel_input_features * feature_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); + } + if (kernel_output_features % feature_group_count > 0) { + return InvalidArgument( + "Expected output feature dimension (value %d) to be divisible by " + "feature_group_count (value %d); " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + kernel_output_features, feature_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); } std::vector window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -1659,8 +1693,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "RHS shape: %s\n\t" "Window: {%s}\n\t" "Dimension numbers: {%s}.", - ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(), - dnums.ShortDebugString().c_str()); + ShapeUtil::HumanString(rhs), window.ShortDebugString(), + dnums.ShortDebugString()); } Shape base_shape = @@ -1683,32 +1717,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferFftShape( const Shape& in, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { + const absl::Span fft_length) { const int64 fft_rank = fft_length.size(); if (fft_rank < 1 || fft_rank > 3) { - return InvalidArgument("FFT only supports ranks 1-3; got %lld.", fft_rank); + return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank); } -#define RET_CHECK_RANK(x) \ - if (x.dimensions_size() < fft_rank) { \ - return InvalidArgument( \ - "FFT of rank %lld requires input of at least " \ - "same rank; got input of rank %d", \ - fft_rank, x.dimensions_size()); \ +#define RET_CHECK_RANK(x) \ + if (x.dimensions_size() < fft_rank) { \ + return InvalidArgument( \ + "FFT of rank %d requires input of at least " \ + "same rank; got input of rank %d", \ + fft_rank, x.dimensions_size()); \ } switch (fft_type) { case FFT: case IFFT: if (in.element_type() != C64) { return InvalidArgument("%s requires C64 input type, found %s.", - FftType_Name(fft_type).c_str(), - PrimitiveType_Name(in.element_type()).c_str()); + FftType_Name(fft_type), + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); return in; case RFFT: { if (in.element_type() != F32) { return InvalidArgument("RFFT requires F32 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); for (int i = 0; i < fft_rank; i++) { @@ -1716,7 +1750,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "RFFT requires innermost dimensions match fft_length but " - "dimension %lld is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1730,7 +1764,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case IRFFT: { if (in.element_type() != C64) { return InvalidArgument("IRFFT requires C64 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); Shape result = ShapeUtil::ComplexComponentShape(in); @@ -1739,7 +1773,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "IRFFT requires all but one innermost dimensions match " - "fft_length, but dimension %lld is %lld and should be %lld.", + "fft_length, but dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1749,7 +1783,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[fft_rank - 1] / 2 + 1) { return InvalidArgument( "IRFFT requires innermost dimension matches fft_length/2+1, but " - "dimension %d is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), fft_length[fft_rank - 1] / 2 + 1); } @@ -1764,7 +1798,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( - tensorflow::gtl::ArraySlice operand_shapes) { + absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( ExpectArray(*operand_shape, "operand of cross replica sum")); @@ -1779,9 +1813,60 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(operand_shape_values); } +/* static */ StatusOr ShapeInference::InferAllToAllShape( + const Shape& shape, int64 split_dimension, int64 concat_dimension, + int64 split_count) { + TF_RET_CHECK(split_count > 0); + if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + return InvalidArgument( + "AllToAll split_dimension %d is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape)); + } + if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + return InvalidArgument( + "AllToAll concat_dimension %d is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape)); + } + if (shape.dimensions(split_dimension) % split_count != 0) { + return InvalidArgument( + "AllToAll split dimension size %d must be dividable by split_count " + "%d.", + shape.dimensions(split_dimension), split_count); + } + std::vector new_dimensions(shape.dimensions().begin(), + shape.dimensions().end()); + new_dimensions[split_dimension] /= split_count; + new_dimensions[concat_dimension] *= split_count; + return ShapeUtil::MakeShape(shape.element_type(), new_dimensions); +} + +/* static */ StatusOr ShapeInference::InferAllToAllTupleShape( + absl::Span operand_shapes) { + // An Alltoall HLO instruction receives N operands (with the same shape) and + // returns a tuple that contains N array shapes. + TF_RET_CHECK(!operand_shapes.empty()); + for (int i = 0; i < operand_shapes.size(); i++) { + if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) { + return InvalidArgument( + "HLO all-to-all has operands with different shapes: the 0th " + "operand shape %s, but the %dth operand has shape %s.", + ShapeUtil::HumanString(*operand_shapes[0]), i, + ShapeUtil::HumanString(*operand_shapes[i])); + } + } + + return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); +} + +/* static */ StatusOr ShapeInference::InferCollectivePermuteShape( + const Shape& shape) { + TF_RET_CHECK(ShapeUtil::IsArray(shape)); + return shape; +} + /* static */ StatusOr ShapeInference::InferReduceShape( - tensorflow::gtl::ArraySlice arg_shapes, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span arg_shapes, + absl::Span dimensions_to_reduce, const ProgramShape& to_apply) { if (arg_shapes.empty()) { return InvalidArgument("Reduce must have at least 2 arguments, has 0"); @@ -1793,17 +1878,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } int64 num_reduced_args = arg_shapes.size() / 2; - tensorflow::gtl::ArraySlice reduced_args(arg_shapes, 0, - num_reduced_args); + auto reduced_args = arg_shapes.subspan(0, num_reduced_args); // Check that all of the reduced tensors have the same dimensions. The element // types may be different. for (int64 i = 1; i < num_reduced_args; ++i) { if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( "All reduced tensors must have the sime dimension. Tensor 0 has " - "shape %s, Tensor %lld has shape %s", - ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, - ShapeUtil::HumanString(*reduced_args[i]).c_str()); + "shape %s, Tensor %d has shape %s", + ShapeUtil::HumanString(*reduced_args[0]), i, + ShapeUtil::HumanString(*reduced_args[i])); } } @@ -1813,14 +1897,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { - return InvalidArgument( - "Reducing out-of-bounds dimension %lld in shape %s.", dimension, - ShapeUtil::HumanString(arg).c_str()); + return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.", + dimension, ShapeUtil::HumanString(arg)); } } - tensorflow::gtl::ArraySlice init_values( - arg_shapes, num_reduced_args, arg_shapes.size()); + auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size()); std::vector element_types; for (const Shape* arg : reduced_args) { element_types.push_back(arg->element_type()); @@ -1888,16 +1970,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Select function's first parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(0)), + ShapeUtil::HumanString(operand_element_shape)); } if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(1))) { return InvalidArgument( "Select function's second parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(1)), + ShapeUtil::HumanString(operand_element_shape)); } // Check if the scatter function has a proper shape as a reduction. @@ -1915,43 +1997,40 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Source shape does not match the shape of window-reduced operand: " "source(%s), window-reduced operand(%s).", - ShapeUtil::HumanString(source_shape).c_str(), - ShapeUtil::HumanString(window_result_shape).c_str()); + ShapeUtil::HumanString(source_shape), + ShapeUtil::HumanString(window_result_shape)); } return operand_shape; } /* static */ StatusOr ShapeInference::InferSliceShape( - const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits, - tensorflow::gtl::ArraySlice strides) { + const Shape& arg, absl::Span starts, + absl::Span limits, absl::Span strides) { auto error = [&](const string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " "{%s}; strides: {%s}.", - message.c_str(), ShapeUtil::HumanString(arg).c_str(), - Join(starts, ",").c_str(), Join(limits, ",").c_str(), - Join(strides, ",").c_str()); + message, ShapeUtil::HumanString(arg), StrJoin(starts, ","), + StrJoin(limits, ","), StrJoin(strides, ",")); }; TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); - VLOG(2) << tensorflow::strings::Printf( - "slicing shape %s starts={%s} limits={%s}", - ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), - Join(limits, ", ").c_str()); + VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}", + ShapeUtil::HumanString(arg), StrJoin(starts, ", "), + StrJoin(limits, ", ")); if (starts.size() != limits.size()) { - return error(Printf("slice start and limit sizes differ: %zu vs %zu", - starts.size(), limits.size())); + return error(StrFormat("slice start and limit sizes differ: %u vs %u", + starts.size(), limits.size())); } if (starts.size() != strides.size()) { - return error(Printf("slice start and strides sizes differ: %zu vs %zu", - starts.size(), strides.size())); + return error(StrFormat("slice start and strides sizes differ: %u vs %u", + starts.size(), strides.size())); } if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( - "Slice index count does not match argument rank: %zu vs %lld.", + "Slice index count does not match argument rank: %u vs %d.", starts.size(), ShapeUtil::Rank(arg)); } @@ -1961,27 +2040,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 limit_index = limits[dimension]; int64 stride = strides[dimension]; if (start_index < 0) { - return InvalidArgument("Negative start index to slice: %lld.", - start_index); + return InvalidArgument("Negative start index to slice: %d.", start_index); } if (limit_index > arg.dimensions(dimension)) { return error( - Printf("limit index (%lld) must be less than or equal to dimension " - "size (%lld)", - limit_index, arg.dimensions(dimension))); - } - VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, - start_index); - VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, - limit_index); + StrFormat("limit index (%d) must be less than or equal to dimension " + "size (%d)", + limit_index, arg.dimensions(dimension))); + } + VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index); + VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index); if (start_index > limit_index) { return error( - Printf("limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - limit_index, start_index)); + StrFormat("limit index (%d) must be greater or equal to " + "start index (%d) in slice with positive stride", + limit_index, start_index)); } if (stride <= 0) { - return InvalidArgument("Stride (%lld) must be positive.", stride); + return InvalidArgument("Stride (%d) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); } @@ -1991,20 +2067,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); TF_RETURN_IF_ERROR( ExpectArray(start_indices_shape, "start indices of dynamic slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", - ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - Join(slice_sizes, ", ").c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic slice start indices of rank %lld must be rank1.", + "Dynamic slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2016,16 +2091,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + "Dynamic slice start number of dimensions %d (%s) must match rank " + "%d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); } if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "Dynamic slice index count does not match argument rank: %zu vs %lld.", + "Dynamic slice index count does not match argument rank: %u vs %d.", slice_sizes.size(), ShapeUtil::Rank(operand_shape)); } @@ -2033,16 +2107,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 input_dim_size = operand_shape.dimensions(dim); const int64 slice_dim_size = slice_sizes[dim]; if (slice_dim_size < 0) { - return InvalidArgument("Negative size index to dynamic slice: %lld.", + return InvalidArgument("Negative size index to dynamic slice: %d.", slice_dim_size); } if (slice_dim_size > input_dim_size) { return InvalidArgument( - "Slice dim size %lld greater than dynamic slice dimension: %lld.", + "Slice dim size %d greater than dynamic slice dimension: %d.", slice_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim, - slice_dim_size); + VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size); } return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes); @@ -2058,16 +2131,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, "start indices of dynamic update slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "updating slice of shape %s at dynamic start_indices %s with update " "shape %s", - ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::HumanString(update_shape).c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::HumanString(update_shape)); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic update slice start indices of rank %lld must be rank1.", + "Dynamic update slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2079,17 +2152,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic update slice start number of dimensions %lld (%s) must match " - "rank %lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + "Dynamic update slice start number of dimensions %d (%s) must match " + "rank %d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); } if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( "Dynamic update slice update rank does not match argument rank: " - "%lld vs %lld.", + "%d vs %d.", ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } @@ -2098,8 +2170,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Dynamic update slice update element type does not match argument. " "operand.element_type: %s vs update.element_type: %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str(), - PrimitiveType_Name(update_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type()), + PrimitiveType_Name(update_shape.element_type())); } for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { @@ -2107,23 +2179,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { return InvalidArgument( - "Size index %lld to dynamic update slice must be >= 0.", + "Size index %d to dynamic update slice must be >= 0.", update_dim_size); } if (update_dim_size > input_dim_size) { return InvalidArgument( - "Update dim size %lld greater than dynamic slice dimension: %lld.", + "Update dim size %d greater than dynamic slice dimension: %d.", update_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim, - update_dim_size); + VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size); } return operand_shape; } /*static */ StatusOr ShapeInference::InferReverseShape( - const Shape& operand_shape, tensorflow::gtl::ArraySlice dimensions) { + const Shape& operand_shape, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { return InvalidArgument("a dimension number is duplicated in reverse"); @@ -2131,8 +2202,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 dimension : dimensions) { if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { return InvalidArgument( - "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.", - dimension, ShapeUtil::HumanString(operand_shape).c_str()); + "One of the reverse dimensions (%d) is out-of-bounds in shape %s.", + dimension, ShapeUtil::HumanString(operand_shape)); } } return operand_shape; @@ -2143,14 +2214,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsTuple(arg)) { return InvalidArgument( "Cannot infer shape: attempting to index into non-tuple: %s.", - ShapeUtil::HumanString(arg).c_str()); + ShapeUtil::HumanString(arg)); } if (index >= arg.tuple_shapes_size()) { return InvalidArgument( - "Cannot infer shape: attempt to index out of tuple bounds: %lld " + "Cannot infer shape: attempt to index out of tuple bounds: %d " ">= %d in shape %s.", - index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str()); + index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg)); } return arg.tuple_shapes(index); @@ -2170,17 +2241,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } auto shape_string = [&]() { - return tensorflow::strings::Printf( - "Condition: %s; body: %s; init: %s.", - ShapeUtil::HumanString(condition).c_str(), - ShapeUtil::HumanString(body).c_str(), - ShapeUtil::HumanString(init).c_str()); + return StrFormat( + "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition), + ShapeUtil::HumanString(body), ShapeUtil::HumanString(init)); }; // Check the shapes of computation parameters and return types. if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { return InvalidArgument("Condition must return a boolean; got %s.", - shape_string().c_str()); + shape_string()); } if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) || !ShapeUtil::Compatible(body.result(), body.parameters(0)) || @@ -2188,7 +2257,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The parameter of condition and body, the result of the body, and init " "must all have the same shape; got %s.", - shape_string().c_str()); + shape_string()); } return init; @@ -2200,7 +2269,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const ProgramShape& false_computation) { if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { return InvalidArgument("Predicate must be a boolean; got %s.", - ShapeUtil::HumanString(predicate).c_str()); + ShapeUtil::HumanString(predicate)); } if (true_computation.parameters_size() != 1) { @@ -2209,15 +2278,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { auto true_shape_string = [&]() { - return tensorflow::strings::Printf( - "true_operand: %s; true_computation: %s", - ShapeUtil::HumanString(true_operand).c_str(), - ShapeUtil::HumanString(true_computation).c_str()); + return StrFormat("true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand), + ShapeUtil::HumanString(true_computation)); }; return InvalidArgument( "true_operand must match the shape of the only parameter of " "true_computation: got %s.", - true_shape_string().c_str()); + true_shape_string()); } if (false_computation.parameters_size() != 1) { @@ -2226,38 +2294,37 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { auto false_shape_string = [&]() { - return tensorflow::strings::Printf( - "false_operand: %s; false_computation: %s", - ShapeUtil::HumanString(false_operand).c_str(), - ShapeUtil::HumanString(false_computation).c_str()); + return StrFormat("false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand), + ShapeUtil::HumanString(false_computation)); }; return InvalidArgument( "false_operand must match the shape of the only parameter of " "false_computation: got %s.", - false_shape_string().c_str()); + false_shape_string()); } if (!ShapeUtil::Compatible(true_computation.result(), false_computation.result())) { auto shape_string = [&]() { - return tensorflow::strings::Printf( + return StrFormat( "true_computation result: %s; false_computation result: %s.", - ShapeUtil::HumanString(true_computation.result()).c_str(), - ShapeUtil::HumanString(false_computation.result()).c_str()); + ShapeUtil::HumanString(true_computation.result()), + ShapeUtil::HumanString(false_computation.result())); }; return InvalidArgument( "the result of true_computation and false_computation must have the " "same shape: got %s.", - shape_string().c_str()); + shape_string()); } return true_computation.result(); } /* static */ StatusOr ShapeInference::InferBroadcastShape( - const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { + const Shape& operand, absl::Span broadcast_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { - return InvalidArgument("Broadcast with negative dimension size %lld.", + return InvalidArgument("Broadcast with negative dimension size %d.", size); } } @@ -2271,8 +2338,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferReshapeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { + const Shape& operand, absl::Span dimensions, + absl::Span new_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = @@ -2282,11 +2349,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "Reshape operation has mismatched element counts: from=%lld (%s) " - "to=%lld (%s).", - ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), + "Reshape operation has mismatched element counts: from=%d (%s) " + "to=%d (%s).", + ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand), ShapeUtil::ElementsIn(inferred_shape), - ShapeUtil::HumanString(inferred_shape).c_str()); + ShapeUtil::HumanString(inferred_shape)); } std::vector indices(ShapeUtil::Rank(operand)); @@ -2297,14 +2364,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Reshape dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", - Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str()); + StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } return inferred_shape; } /* static */ StatusOr ShapeInference::InferTransposeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions) { + const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); std::vector indices(ShapeUtil::Rank(operand)); @@ -2332,9 +2399,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("Clamp with different operand types: %s, %s, %s.", - ShapeUtil::HumanString(min).c_str(), - ShapeUtil::HumanString(operand).c_str(), - ShapeUtil::HumanString(max).c_str()); + ShapeUtil::HumanString(min), + ShapeUtil::HumanString(operand), + ShapeUtil::HumanString(max)); } if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || ShapeUtil::IsScalar(min)) && @@ -2351,9 +2418,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::ChangeElementType(min, operand.element_type()); } } - return Unimplemented( - "%s, %s %s is not implemented.", min.ShortDebugString().c_str(), - max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); + return Unimplemented("%s, %s %s is not implemented.", + min.ShortDebugString(), max.ShortDebugString(), + operand.ShortDebugString()); } // TODO(b/36794510): Make broadcast semantics more consistent, by supporting @@ -2364,13 +2431,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if (pred.element_type() != PRED) { return InvalidArgument( "Select's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || ShapeUtil::IsScalar(pred)) { @@ -2383,7 +2449,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Select operation with non-scalar predicate with dimensionality " " different from the other operands: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } } @@ -2394,38 +2460,36 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::Compatible(on_true, on_false)) { return InvalidArgument( "Operands to tuple-select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if (pred.element_type() != PRED) { return InvalidArgument( "TupleSelect's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (!ShapeUtil::IsScalar(pred)) { return InvalidArgument( "TupleSelect operation with non-scalar predicate: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } return on_true; } /* static */ StatusOr ShapeInference::InferCallShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply) { + absl::Span arg_shapes, const ProgramShape& to_apply) { // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { string computation_signature = ShapeUtil::HumanString(to_apply); string argument_shapes = - Join(arg_shapes, ", ", [](string* out, const Shape* shape) { - tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape)); + StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) { + absl::StrAppend(out, ShapeUtil::HumanString(*shape)); }); return InvalidArgument( "Call applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu; computation signature: %s; argument " + "arity: %d, arguments: %u; computation signature: %s; argument " "shapes: [%s].", - to_apply.parameters_size(), arg_shapes.size(), - computation_signature.c_str(), argument_shapes.c_str()); + to_apply.parameters_size(), arg_shapes.size(), computation_signature, + argument_shapes); } // All arguments must be compatible with the program shape. @@ -2436,8 +2500,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Call parameter must match argument; got parameter %d shape: %s, " "argument shape: %s.", - i, ShapeUtil::HumanString(param_shape).c_str(), - ShapeUtil::HumanString(arg_shape).c_str()); + i, ShapeUtil::HumanString(param_shape), + ShapeUtil::HumanString(arg_shape)); } } @@ -2445,202 +2509,198 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } static Status ValidateGatherDimensionNumbers( - const Shape& input_shape, - tensorflow::gtl::ArraySlice gather_indices_shape, + const Shape& input_shape, absl::Span start_indices_shape, const GatherDimensionNumbers& dim_numbers) { - if (!c_is_sorted(dim_numbers.output_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.output_window_dims()) != - dim_numbers.output_window_dims().end()) { + if (absl::c_adjacent_find(dim_numbers.offset_dims()) != + dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } - const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); + const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); const int64 output_shape_rank = - output_window_dim_count + gather_indices_shape.size() - 1; + output_offset_dim_count + start_indices_shape.size() - 1; - for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { - int64 window_index = dim_numbers.output_window_dims(i); - if (window_index < 0 || window_index >= output_shape_rank) { + for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) { + int64 offset_dim = dim_numbers.offset_dims(i); + if (offset_dim < 0 || offset_dim >= output_shape_rank) { return InvalidArgument( - "Window index %d in gather op is out of bounds; got %lld, but should " - "have been in [0,%lld).", - i, window_index, output_shape_rank); + "Offset dimension %d in gather op is out of bounds; got %d, but " + "should " + "have been in [0,%d).", + i, offset_dim, output_shape_rank); } } - if (dim_numbers.gather_dims_to_operand_dims_size() != - gather_indices_shape[dim_numbers.index_vector_dim()]) { + if (dim_numbers.start_index_map_size() != + start_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( - "Gather op has %d elements in gather_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of gather_indices is " - "%lld. These two numbers must be equal.", - dim_numbers.gather_dims_to_operand_dims_size(), - dim_numbers.index_vector_dim(), - gather_indices_shape[dim_numbers.index_vector_dim()]); + "Gather op has %d elements in start_index_map and the " + "bound of dimension index_vector_dim=%d of start_indices is " + "%d. These two numbers must be equal.", + dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), + start_indices_shape[dim_numbers.index_vector_dim()]); } - for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { - int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i); - if (gather_dim_to_input_dim < 0 || - gather_dim_to_input_dim >= input_shape.dimensions_size()) { + for (int i = 0; i < dim_numbers.start_index_map_size(); i++) { + int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i); + if (operand_dim_for_start_index_i < 0 || + operand_dim_for_start_index_i >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", - input_shape.dimensions_size(), i, gather_dim_to_input_dim); + "Invalid start_index_map; domain is [0, %d), got: %d->%d.", + input_shape.dimensions_size(), i, operand_dim_for_start_index_i); } } - std::vector sorted_gather_dims_to_operand_dims( - dim_numbers.gather_dims_to_operand_dims().begin(), - dim_numbers.gather_dims_to_operand_dims().end()); + std::vector sorted_start_index_map( + dim_numbers.start_index_map().begin(), + dim_numbers.start_index_map().end()); - c_sort(sorted_gather_dims_to_operand_dims); + absl::c_sort(sorted_start_index_map); - if (c_adjacent_find(sorted_gather_dims_to_operand_dims) != - sorted_gather_dims_to_operand_dims.end()) { + if (absl::c_adjacent_find(sorted_start_index_map) != + sorted_start_index_map.end()) { return InvalidArgument( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " + "Repeated dimensions are not allowed in start_index_map; " "got: %s.", - Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); + StrJoin(dim_numbers.start_index_map(), ", ")); } - for (int64 elided_dim : dim_numbers.elided_window_dims()) { - if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { + for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) { + if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid elided_window_dims set in gather op; valid range is [0, " - "%d), got: %lld.", - input_shape.dimensions_size(), elided_dim); + "Invalid collapsed_slice_dims set in gather op; valid range is [0, " + "%d), got: %d.", + input_shape.dimensions_size(), collapsed_dim); } } - if (!c_is_sorted(dim_numbers.elided_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( - "elided_window_dims in gather op must be sorted; got: %s", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + "collapsed_slice_dims in gather op must be sorted; got: %s", + StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.elided_window_dims()) != - dim_numbers.elided_window_dims().end()) { + if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) != + dim_numbers.collapsed_slice_dims().end()) { return InvalidArgument( - "Repeated dimensions not allowed in elided_window_dims in gather op; " + "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " "got: %s.", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } return Status::OK(); } /*static*/ StatusOr ShapeInference::InferGatherShape( - const Shape& input_shape, const Shape& gather_indices_shape, + const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) { + absl::Span slice_sizes) { TF_RETURN_IF_ERROR( ExpectArray(input_shape, "input tensor operand gather op")); TF_RETURN_IF_ERROR( - ExpectArray(gather_indices_shape, "gather indices operand of gather op")); + ExpectArray(start_indices_shape, "gather indices operand of gather op")); - if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( "Gather indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(gather_indices_shape).c_str()); + ShapeUtil::HumanString(start_indices_shape)); } // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if // index_vector_dim is rank(P). The bounds of this expanded shape is - // stored in expanded_gather_indices_shape. + // stored in expanded_start_indices_shape. - if (gather_indices_shape.dimensions_size() < + if (start_indices_shape.dimensions_size() < gather_dim_numbers.index_vector_dim() || gather_dim_numbers.index_vector_dim() < 0) { return InvalidArgument( - "Gather index leaf dimension must be within [0, rank(gather_indices) + " - "1). rank(gather_indices) is %d and gather index leaf dimension is " - "%lld.", - gather_indices_shape.dimensions_size(), + "Gather index leaf dimension must be within [0, rank(start_indices) + " + "1). rank(start_indices) is %d and gather index leaf dimension is " + "%d.", + start_indices_shape.dimensions_size(), gather_dim_numbers.index_vector_dim()); } - std::vector expanded_gather_indices_shape; - expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); - c_copy(gather_indices_shape.dimensions(), - std::back_inserter(expanded_gather_indices_shape)); - if (expanded_gather_indices_shape.size() == + std::vector expanded_start_indices_shape; + expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); + absl::c_copy(start_indices_shape.dimensions(), + std::back_inserter(expanded_start_indices_shape)); + if (expanded_start_indices_shape.size() == gather_dim_numbers.index_vector_dim()) { - expanded_gather_indices_shape.push_back(1); + expanded_start_indices_shape.push_back(1); } TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( - input_shape, expanded_gather_indices_shape, gather_dim_numbers)); + input_shape, expanded_start_indices_shape, gather_dim_numbers)); - if (window_bounds.size() != input_shape.dimensions_size()) { + if (slice_sizes.size() != input_shape.dimensions_size()) { return InvalidArgument( - "Gather op must have one window bound for every input dimension; got: " - "len(window_bounds)=%lu, input_shape.rank=%d.", - window_bounds.size(), input_shape.dimensions_size()); + "Gather op must have one slice size for every input dimension; got: " + "len(slice_sizes)=%lu, input_shape.rank=%d.", + slice_sizes.size(), input_shape.dimensions_size()); } - if (window_bounds.size() != - gather_dim_numbers.output_window_dims_size() + - gather_dim_numbers.elided_window_dims_size()) { + if (slice_sizes.size() != + gather_dim_numbers.offset_dims_size() + + gather_dim_numbers.collapsed_slice_dims_size()) { return InvalidArgument( - "All components of the window index in a gather op must either be a " - "output window index or explicitly elided; got len(window_bounds)=%lu, " - "output_window_bounds=%s, elided_window_bounds=%s.", - window_bounds.size(), - Join(gather_dim_numbers.output_window_dims(), ",").c_str(), - Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); + "All components of the offset index in a gather op must either be a " + "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " + "output_slice_sizes=%s, collapsed_slice_dims=%s.", + slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","), + StrJoin(gather_dim_numbers.collapsed_slice_dims(), ",")); } - for (int i = 0; i < window_bounds.size(); i++) { - int64 window_bound = window_bounds[i]; - int64 corresponding_input_bound = input_shape.dimensions(i); - if (window_bound < 0 || window_bound > corresponding_input_bound) { + for (int i = 0; i < slice_sizes.size(); i++) { + int64 slice_size = slice_sizes[i]; + int64 corresponding_input_size = input_shape.dimensions(i); + if (slice_size < 0 || slice_size > corresponding_input_size) { return InvalidArgument( - "Window bound at index %d in gather op is out of range, must be " - "within " - "[0, %lld), got %lld.", - i, corresponding_input_bound + 1, window_bound); + "Slice size at index %d in gather op is out of range, must be " + "within [0, %d), got %d.", + i, corresponding_input_size + 1, slice_size); } } - for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) { - if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { + for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) { + if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) { return InvalidArgument( - "Gather op can only elide window indices with bound 1, but bound is " - "%lld for index %lld at position %d.", - window_bounds[gather_dim_numbers.elided_window_dims(i)], - gather_dim_numbers.elided_window_dims(i), i); + "Gather op can only collapse slice dims with bound 1, but bound is " + "%d for index %d at position %d.", + slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)], + gather_dim_numbers.collapsed_slice_dims(i), i); } } - int64 result_rank = gather_dim_numbers.output_window_dims_size() + - (expanded_gather_indices_shape.size() - 1); - int64 window_dims_seen = 0; + int64 result_rank = gather_dim_numbers.offset_dims_size() + + (expanded_start_indices_shape.size() - 1); + int64 offset_dims_seen = 0; int64 gather_dims_seen = 0; std::vector output_dim_bounds; output_dim_bounds.reserve(result_rank); for (int64 i = 0; i < result_rank; i++) { int64 current_bound; bool is_window_index = - c_binary_search(gather_dim_numbers.output_window_dims(), i); + absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { - while (c_binary_search(gather_dim_numbers.elided_window_dims(), - window_dims_seen)) { - window_dims_seen++; + while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(), + offset_dims_seen)) { + offset_dims_seen++; } - current_bound = window_bounds[window_dims_seen++]; + current_bound = slice_sizes[offset_dims_seen++]; } else { if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { gather_dims_seen++; } - current_bound = expanded_gather_indices_shape[gather_dims_seen++]; + current_bound = expanded_start_indices_shape[gather_dims_seen++]; } output_dim_bounds.push_back(current_bound); @@ -2652,48 +2712,47 @@ static Status ValidateGatherDimensionNumbers( namespace { Status ValidateScatterDimensionNumbers( - const Shape& operand_shape, - tensorflow::gtl::ArraySlice scatter_indices_shape, + const Shape& operand_shape, absl::Span scatter_indices_shape, const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { // Validate update_window_dims in ScatterDimensionNumbers. - if (!c_is_sorted(dim_numbers.update_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.update_window_dims())) { return InvalidArgument( "update_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.update_window_dims()) != + if (absl::c_adjacent_find(dim_numbers.update_window_dims()) != dim_numbers.update_window_dims().end()) { return InvalidArgument( "update_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } const int64 updates_rank = ShapeUtil::Rank(updates_shape); for (int64 window_dim : dim_numbers.update_window_dims()) { if (window_dim < 0 || window_dim >= updates_rank) { return InvalidArgument( "Invalid update_window_dims set in scatter op; valid range is [0, " - "%lld). got: %lld.", + "%d). got: %d.", updates_rank, window_dim); } } // Validate inserted_window_dims in ScatterDimensionNumbers. - if (!c_is_sorted(dim_numbers.inserted_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) { return InvalidArgument( "inserted_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.inserted_window_dims()) != + if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) != dim_numbers.inserted_window_dims().end()) { return InvalidArgument( "inserted_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid inserted_window_dims set in scatter op; valid range is [0, " - "%d), got: %lld.", + "%d), got: %d.", operand_shape.dimensions_size(), inserted_dim); } } @@ -2703,7 +2762,7 @@ Status ValidateScatterDimensionNumbers( scatter_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( "Scatter op has %d elements in scatter_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "bound of dimension index_vector_dim=%d of scatter_indices is %d. " "These two numbers must be equal.", dim_numbers.scatter_dims_to_operand_dims_size(), dim_numbers.index_vector_dim(), @@ -2716,20 +2775,20 @@ Status ValidateScatterDimensionNumbers( scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", + "got: %d->%d.", operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); } } std::vector sorted_scatter_dims_to_operand_dims( dim_numbers.scatter_dims_to_operand_dims().begin(), dim_numbers.scatter_dims_to_operand_dims().end()); - c_sort(sorted_scatter_dims_to_operand_dims); - if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + absl::c_sort(sorted_scatter_dims_to_operand_dims); + if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) != sorted_scatter_dims_to_operand_dims.end()) { return InvalidArgument( "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " "got: %s.", - Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", ")); } return Status::OK(); @@ -2750,7 +2809,7 @@ Status ValidateScatterDimensionNumbers( if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { return InvalidArgument( "Scatter indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(scatter_indices_shape).c_str()); + ShapeUtil::HumanString(scatter_indices_shape)); } if (scatter_indices_shape.dimensions_size() < @@ -2759,7 +2818,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Scatter index leaf dimension must be within [0, rank(scatter_indices)" " + 1). rank(scatter_indices) is %d and scatter index leaf dimension " - "is %lld.", + "is %d.", scatter_indices_shape.dimensions_size(), scatter_dim_numbers.index_vector_dim()); } @@ -2781,7 +2840,7 @@ Status ValidateScatterDimensionNumbers( int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + scatter_dim_numbers.update_window_dims_size(); if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { - return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + return InvalidArgument("Updates tensor must be of rank %d; got %d.", expected_updates_rank, ShapeUtil::Rank(updates_shape)); } @@ -2791,32 +2850,32 @@ Status ValidateScatterDimensionNumbers( scatter_dim_numbers)); int64 inserted_dims_seen = 0; - std::vector max_update_window_bounds; + std::vector max_update_slice_sizes; for (int i = 0; i < operand_shape.dimensions_size(); ++i) { if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { ++inserted_dims_seen; } else { - max_update_window_bounds.push_back(operand_shape.dimensions(i)); + max_update_slice_sizes.push_back(operand_shape.dimensions(i)); } } for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { auto update_window_dim = scatter_dim_numbers.update_window_dims(i); if (updates_shape.dimensions(update_window_dim) > - max_update_window_bounds[i]) { + max_update_slice_sizes[i]) { return InvalidArgument( "Bounds of the window dimensions of updates must not exceed the " "bounds of the corresponding dimensions of operand. For dimension " - "%lld, updates bound is %lld, operand bound is %lld.", + "%d, updates bound is %d, operand bound is %d.", update_window_dim, updates_shape.dimensions(update_window_dim), - max_update_window_bounds[i]); + max_update_slice_sizes[i]); } } int64 scatter_dims_seen = 0; for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { bool is_update_window_dim = - c_binary_search(scatter_dim_numbers.update_window_dims(), i); + absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i); if (is_update_window_dim) { continue; } @@ -2828,8 +2887,8 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Bounds of the scatter dimensions of updates must be same as the " "bounds of the corresponding dimensions of scatter indices. For " - "scatter dimension %lld, updates bound is %lld, scatter_indices " - "bound is %lld.", + "scatter dimension %d, updates bound is %d, scatter_indices " + "bound is %d.", i, updates_shape.dimensions(i), expanded_scatter_indices_shape[scatter_dims_seen]); } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 33da323b3d74848e10fb736aa77123b0a3946556..96a0ee165d46753da4fef119e7072f66637bf2c4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -21,12 +21,12 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -55,7 +55,7 @@ class ShapeInference { // given input shapes. static StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); static StatusOr InferBinaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs); @@ -73,18 +73,15 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. static StatusOr InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operand_shapes); + HloOpcode opcode, absl::Span operand_shapes); static StatusOr InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + HloOpcode opcode, absl::Span operands); // Infers the shape produced by applying the given mapping computation shape // to the given operand shapes. static StatusOr InferMapShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply, - tensorflow::gtl::ArraySlice dimensions); + absl::Span arg_shapes, const ProgramShape& to_apply, + absl::Span dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. @@ -111,18 +108,32 @@ class ShapeInference { // Infers the shape produced by applying the given convolutional // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); // Infers the shape produced by the given FFT type on the given operand. - static StatusOr InferFftShape( - const Shape& in, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + static StatusOr InferFftShape(const Shape& in, FftType fft_type, + absl::Span fft_length); - // Infers the shape produced a cross replica sum with the given operand + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferCrossReplicaSumShape( - tensorflow::gtl::ArraySlice operand_shapes); + absl::Span operand_shapes); + + // Infers final shape of an Alltoall operation that is created by the xla + // builder. + static StatusOr InferAllToAllShape(const Shape& shape, + int64 split_dimension, + int64 concat_dimension, + int64 split_count); + + // Infers the shape of an HLO all-to-all instruction. + static StatusOr InferAllToAllTupleShape( + absl::Span operand_shapes); + + // Infers the shape of a collective permute operation. + static StatusOr InferCollectivePermuteShape(const Shape& shape); // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. @@ -131,8 +142,8 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). static StatusOr InferReduceShape( - tensorflow::gtl::ArraySlice arg_shapes, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span arg_shapes, + absl::Span dimensions_to_reduce, const ProgramShape& to_apply); // Infers the shape produced by applying the given computation to the operand @@ -150,24 +161,23 @@ class ShapeInference { // Infers the shape produced by a reverse operation that reverses the order // of the elements in the given dimensions. - static StatusOr InferReverseShape( - const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimensions); + static StatusOr InferReverseShape(const Shape& operand_shape, + absl::Span dimensions); // Infers the shape produced by a slice operation spanning from the starts to // the limits in the original shape's dimensions. // // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] - static StatusOr InferSliceShape( - const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits, - tensorflow::gtl::ArraySlice strides); + static StatusOr InferSliceShape(const Shape& arg, + absl::Span starts, + absl::Span limits, + absl::Span strides); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static StatusOr InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. @@ -198,30 +208,30 @@ class ShapeInference { // Infers the shape produced by a broadcast operation. static StatusOr InferBroadcastShape( - const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); + const Shape& operand, absl::Span broadcast_sizes); // Infers the shape produced by a reshape operation from the element type of // its operand and the new dimension sizes specified. - static StatusOr InferReshapeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + static StatusOr InferReshapeShape(const Shape& operand, + absl::Span dimensions, + absl::Span new_sizes); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. static StatusOr InferTransposeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions); + const Shape& operand, absl::Span dimensions); // Helper that infers the shape produced by performing a concatenate operation // with the given operand shapes. static StatusOr InferConcatOpShape( - tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + absl::Span arg_shapes, int64 dimension); // Infers the shape produced by a kAfterAll. Trivially this shape is always a // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes // and checking operand shapes. This method verifies that the operand shapes // are all TOKENs. static StatusOr InferAfterAllShape( - tensorflow::gtl::ArraySlice arg_shapes); + absl::Span arg_shapes); // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that @@ -251,8 +261,7 @@ class ShapeInference { // Helper that validates the given arg_shapes are compatible with the shape of // the to_apply parameters, and returns the to_apply result shape. static StatusOr InferCallShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply); + absl::Span arg_shapes, const ProgramShape& to_apply); // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. @@ -264,9 +273,9 @@ class ShapeInference { // with the given input shape, gather indices shape and gather dimension // numbers. static StatusOr InferGatherShape( - const Shape& input_shape, const Shape& gather_indices_shape, + const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); // Helper that validates the given input shape, scatter indices shape, updates // shape, and scatter dimension numbers that constitute a scatter operation, @@ -284,7 +293,7 @@ class ShapeInference { // even in the presence of broadcasting of one of the operands over the other. static StatusOr InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Helper for inferring the shape of Clamp ops. static StatusOr InferClampShape(const Shape& min, const Shape& operand, @@ -312,7 +321,7 @@ class ShapeInference { // smaller_shape is broadcast to. static StatusOr InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); }; diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index a73fa181cdd13dc7fabcdc367ae117e19bdc3e5f..864ed43118cd066f6ce14cd808b873f137b8414a 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -17,18 +17,17 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { -using ::tensorflow::gtl::ArraySlice; using ::testing::ContainsRegex; using ::testing::HasSubstr; @@ -58,9 +57,9 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { // Helper that runs reduce shape inference with the input 'arg' and given // dimensions to reduce, and checks the inferred shape is as expected. The // element type here is hard-coded to F32. - void ExpectInferredReduceShape( - const Shape& expected_inferred_shape, const Shape& arg, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + void ExpectInferredReduceShape(const Shape& expected_inferred_shape, + const Shape& arg, + absl::Span dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( {&arg, &f32_}, dimensions_to_reduce, to_apply); @@ -252,7 +251,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, - const tensorflow::gtl::ArraySlice& bcast) { + const absl::Span& bcast) { return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, bcast); }; @@ -420,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_padding_high(0); dim1->set_window_dilation(1); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), @@ -465,8 +464,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(2); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), @@ -510,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(1); dim1->set_base_dilation(2); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), @@ -548,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_stride(2); dim1->set_padding_low(1); dim1->set_padding_high(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("each dimension exactly once")); @@ -1654,11 +1653,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) { ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); + /*slice_sizes=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); @@ -1669,11 +1668,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) { ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{1}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1684,11 +1683,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) { ShapeInference::InferGatherShape( matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{4}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1700,11 +1699,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) @@ -1717,11 +1716,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1735,11 +1734,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1749,16 +1748,15 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape( - f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3, 4}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{0, 1, 2, 3, 4}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) @@ -1772,11 +1770,11 @@ TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{0, 1, 2, 3}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/0), - /*window_bounds=*/{1, 30, 29, 28, 27})); + /*slice_sizes=*/{1, 30, 29, 28, 27})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) @@ -1787,11 +1785,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for input")) @@ -1802,11 +1800,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for gather indices")) @@ -1817,11 +1815,11 @@ TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather indices parameter must be an integral tensor")) @@ -1833,11 +1831,11 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 8, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 8, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1850,11 +1848,11 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1867,14 +1865,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 99, 100, 101}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 99, 100, 101}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 2 in gather op is out of bounds")) + HasSubstr("Offset dimension 2 in gather op is out of bounds")) << statusor.status(); } @@ -1883,14 +1881,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 9}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 9}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 4 in gather op is out of bounds")) + HasSubstr("Offset dimension 4 in gather op is out of bounds")) << statusor.status(); } @@ -1899,16 +1897,16 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{4}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{4}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr("All components of the window index in a gather op must either " - "be a output window index or explicitly elided")) + HasSubstr("All components of the offset index in a gather op must either " + "be a offset dimension or explicitly collapsed")) << statusor.status(); } @@ -1917,14 +1915,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 19}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 19}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Invalid elided_window_dims set in gather op; valid " + HasSubstr("Invalid collapsed_slice_dims set in gather op; valid " "range is [0, 5), got: 19")) << statusor.status(); } @@ -1934,16 +1932,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 3}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 3}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr( - "Repeated dimensions not allowed in elided_window_dims in gather op")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Repeated dimensions not allowed in " + "collapsed_slice_dims in gather op")) << statusor.status(); } @@ -1952,17 +1949,16 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " - "the bound of dimension index_vector_dim=4 of " - "gather_indices is 5. These two numbers must be equal.")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather op has 4 elements in start_index_map and " + "the bound of dimension index_vector_dim=4 of " + "start_indices is 5. These two numbers must be equal.")) << statusor.status(); } @@ -1971,16 +1967,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 7}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " - "[0, 5), got: 4->7")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7")) << statusor.status(); } @@ -1989,16 +1983,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 3}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) + HasSubstr("Repeated dimensions are not allowed in start_index_map")) << statusor.status(); } @@ -2007,14 +2000,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{2, 1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{2, 1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 1, 28, 27, 26}); + /*slice_sizes=*/{1, 1, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("elided_window_dims in gather op must be sorted")) + HasSubstr("collapsed_slice_dims in gather op must be sorted")) << statusor.status(); } @@ -2023,15 +2016,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{2}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{2}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 1, 300, 26}); + /*slice_sizes=*/{30, 29, 1, 300, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window bound at index 3 in gather op is out of range, " - "must be within [0, 48), got 300")) + HasSubstr("Slice size at index 3 in gather op is out of range, " + "must be within [0, 48), got 300.")) << statusor.status(); } @@ -2040,16 +2033,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26}); + /*slice_sizes=*/{30, 29, 28, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Gather op must have one window bound for every input dimension")) + HasSubstr("Gather op must have one slice size for every input dimension")) << statusor.status(); } @@ -2058,15 +2050,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26, 20}); + /*slice_sizes=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Gather op can only elide window indices with bound 1, " - "but bound is 29 for index 1 at position 0")) + HasSubstr("Gather op can only collapse slice dims with bound 1, " + "but bound is 29 for index 1 at position 0.")) << statusor.status(); } @@ -2074,16 +2066,16 @@ TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/32), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather index leaf dimension must be within [0, " - "rank(gather_indices) + 1)")) + "rank(start_indices) + 1)")) << statusor.status(); } diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 7d7dcac10b65933d1c81b8aca77465932694bfdb..921a984589bb4fb64058a2a56adfe84fe14af69b 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,20 +18,19 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::Appendf; - ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, const se::Platform* platform, int device_ordinal) @@ -76,7 +75,7 @@ void ShapedBuffer::clear() { } string ShapedBuffer::ToString() const { - string s = tensorflow::strings::StrCat( + string s = absl::StrCat( "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), ", on-device shape=" + @@ -92,9 +91,9 @@ string ShapedBuffer::ToString() const { shape_str = ShapeUtil::HumanStringWithLayout(subshape); } const se::DeviceMemoryBase& memory = buffer(index); - Appendf(&s, " %s%p (%lld bytes) : %s\n", - string(index.size() * 2, ' ').c_str(), memory.opaque(), - memory.size(), shape_str.c_str()); + absl::StrAppendFormat(&s, " %s%p (%d bytes) : %s\n", + string(index.size() * 2, ' '), memory.opaque(), + memory.size(), shape_str); }); return s; } diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index 905a7e82e621f2bf4588b71be5dbab20f892cafe..e1d26da4a20c0105be304b1a34c81515fcdc6b7f 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -84,6 +84,14 @@ class ShapedBuffer { *buffers_.mutable_element(index) = buffer; } + // Sets all buffers. + // + // Precondition: buffers.shape == on_device_shape_ + void set_buffers(ShapeTree buffers) { + CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_)); + buffers_ = std::move(buffers); + } + // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. const ShapeTree& buffers() const { return buffers_; } diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index 0fc243667911651c788e3c1e5f1d39d86170f1ad..d69e6362e91e4696dab3c46d99a981c67b593a1c 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -34,7 +35,7 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { xla::StreamExecutorMemoryAllocator allocator(platform, executors); const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); const int kDeviceOrdinal = 0; - auto scoped_buffer = tensorflow::MakeUnique( + auto scoped_buffer = absl::make_unique( shape, shape, &allocator, kDeviceOrdinal); std::unique_ptr buffer = std::move(scoped_buffer); buffer = nullptr; diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc deleted file mode 100644 index 8cbaac7b3760717bcacb57adc8782a5755c0aa6d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/source_map_util.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/source_map_util.h" - -#include "tensorflow/compiler/xla/util.h" - -namespace xla { -namespace source_map_util { -namespace { - -Status InvalidParameterArgumentV(const OpMetadata& op_metadata, - const char* format, va_list args) { - string message; - tensorflow::strings::Appendv(&message, format, args); - if (!op_metadata.source_file().empty()) { - tensorflow::strings::Appendf(&message, " (%s:%d)", - op_metadata.source_file().c_str(), - op_metadata.source_line()); - } - return InvalidArgument("%s", message.c_str()); -} - -} // namespace - -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) { - va_list args; - va_start(args, format); - Status result = InvalidParameterArgumentV(op_metadata, format, args); - va_end(args); - return result; -} - -Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) { - va_list args; - va_start(args, format); - if (executable != nullptr && executable->has_module()) { - const HloModule& module = executable->module(); - const HloComputation& computation = *module.entry_computation(); - HloInstruction* param = computation.parameter_instruction(parameter_number); - const OpMetadata& metadata = param->metadata(); - Status result = InvalidParameterArgumentV(metadata, format, args); - va_end(args); - return result; - } - Status result = InvalidArgumentV(format, args); - va_end(args); - return result; -} - -} // namespace source_map_util -} // namespace xla diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h index 18e2651abb1600a7b9ffb79de887b8795717e55e..c5a7e17cb44c2b3b5ef145da0d66b4b3160f9531 100644 --- a/tensorflow/compiler/xla/service/source_map_util.h +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" @@ -23,6 +24,19 @@ limitations under the License. namespace xla { namespace source_map_util { +// Creates an INVALID_ARGUMENT status with the given format string. +template +Status InvalidParameterArgument(const OpMetadata& op_metadata, + const absl::FormatSpec& format, + const Args&... args) { + string message = absl::StrFormat(format, args...); + if (!op_metadata.source_file().empty()) { + absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), + op_metadata.source_line()); + } + return InvalidArgument("%s", message); +} + // Creates an INVALID_ARGUMENT status with the given format string. // // Also, attempts to extract the OpMetadata for parameter_number on executable @@ -30,17 +44,21 @@ namespace source_map_util { // // executable may be nullptr, but parameter_number should not be out of bounds // or a CHECK-failure may occur. +template Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(3, 4); - -// As above, but takes the parameter metadata directly instead of extracting it -// from the executable. -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(2, 3); + const absl::FormatSpec& format, + const Args&... args) { + if (executable != nullptr && executable->has_module()) { + const HloModule& module = executable->module(); + const HloComputation& computation = *module.entry_computation(); + HloInstruction* param = computation.parameter_instruction(parameter_number); + const OpMetadata& metadata = param->metadata(); + return InvalidParameterArgument(metadata, format, args...); + } + return InvalidArgument(format, args...); +} } // namespace source_map_util } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc index c0582c6a2d3a05e2ed5aead5faac54e536d350cd..ec09dff9244080d24580cad8ee2359a34a6a4f96 100644 --- a/tensorflow/compiler/xla/service/stream_pool.cc +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/stream_pool.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -28,14 +28,20 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { // Re-use an existing stream from the pool. stream = std::move(streams_.back()); streams_.pop_back(); - VLOG(1) << stream->DebugStreamPointers() - << " StreamPool reusing existing stream"; + if (stream->ok()) { + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool reusing existing stream"; + } else { + VLOG(1) << stream->DebugStreamPointers() + << " stream was not ok, StreamPool deleting"; + stream = nullptr; + } } } if (!stream) { // Create a new stream. - stream = MakeUnique(executor); + stream = absl::make_unique(executor); stream->Init(); VLOG(1) << stream->DebugStreamPointers() << " StreamPool created new stream"; diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc index aaf5c37b0d250f78cb57639255ac9b59e1b462f7..92f47579d31303b39f6f3a1859789588b586db87 100644 --- a/tensorflow/compiler/xla/service/stream_pool_test.cc +++ b/tensorflow/compiler/xla/service/stream_pool_test.cc @@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) { EXPECT_EQ(stream2_ptr, stream3_ptr); } +TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream1->ok()); + + // Return the stream, but hold a handle to it. + se::Stream* stream1_ptr = stream1.get(); + stream1 = nullptr; + + // Now stream1 is back in the pool, force an error on the stream. Here we call + // a method that requires DNN support, which we know the Host platform doesn't + // support. + stream1_ptr->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(stream1_ptr->ok()); + + // Borrow stream2. + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different. They would have been + // the same, but since we forced an error on stream1, it cannot be + // put back into the pool. Sadly we can't just check: + // EXPECT_NE(stream1_ptr, stream2_ptr); + // + // The above should hold logically, but it may fail if the new + // stream instance allocated for stream2 happens to reside in the + // same memory address as stream1, which has been deleted. + // + // The check that stream2->ok() serves as a good-enough check. +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 7232c658b3f0687ac93a83e46a200f88bf202084..a21e586efadb85d18e88e44999283b28f7f65eac 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -27,7 +29,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/notification.h" -using ::tensorflow::strings::StrCat; +using absl::StrCat; namespace xla { /* static */ tensorflow::mutex @@ -40,18 +42,42 @@ TransferManager::GetPlatformTransferManagers() { return r; } -StatusOr> TransferManager::TransferLiteralFromDevice( +StatusOr TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer) { - StatusOr> ret; + StatusOr ret; + se::Stream* substream = stream->GetOrCreateSubStream(); substream->ThenWaitFor(stream); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); tensorflow::Notification n; - TransferLiteralFromDevice(substream, device_buffer, - [&](StatusOr> arg) { - ret = std::move(arg); + Status s; + Literal literal(device_buffer.on_host_shape()); + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + s = status; + n.Notify(); + }); + n.WaitForNotification(); + if (!s.ok()) { + return s; + } + return std::move(literal); +} + +Status TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal) { + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + Status ret; + tensorflow::Notification n; + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + ret = status; n.Notify(); }); n.WaitForNotification(); @@ -73,25 +99,30 @@ Status TransferManager::TransferLiteralToDevice( return substream->BlockHostUntilDone(); } -StatusOr> TransferManager::TransferArrayFromDevice( +StatusOr TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { + StatusOr ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - StatusOr> ret; se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); tensorflow::Notification n; - TransferArrayFromDevice(substream, shape, source, - [&](StatusOr> arg) { - ret = std::move(arg); + Literal literal(shape); + Status s; + TransferArrayFromDevice(substream, shape, source, literal, + [&](Status status) { + s = status; n.Notify(); }); n.WaitForNotification(); - return ret; + if (!s.ok()) { + return s; + } + return std::move(literal); } Status TransferManager::TransferArrayToDevice( @@ -118,7 +149,7 @@ Status TransferManager::TransferArrayToDeviceAsync( if (dest.size() < GetByteSizeRequirement(on_device_shape)) { return FailedPrecondition( "Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", dest.size(), GetByteSizeRequirement(on_device_shape)); } ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, @@ -130,24 +161,25 @@ Status TransferManager::TransferArrayToDeviceAsync( void TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, - std::function>)> done) { + const MutableBorrowingLiteral& literal, std::function done) { if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); - return done(FailedPrecondition("%s", error.c_str())); + return done(FailedPrecondition("%s", error)); } if (source.size() < GetByteSizeRequirement(shape)) { return done( FailedPrecondition("Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", source.size(), GetByteSizeRequirement(shape))); } ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done)); + return TransferLiteralFromDevice(stream, shaped_buffer, literal, + std::move(done)); } /* static */ void TransferManager::RegisterTransferManager( @@ -171,7 +203,7 @@ void TransferManager::TransferArrayFromDevice( return NotFound( "could not find registered transfer manager for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if (it->second.manager == nullptr) { @@ -222,7 +254,7 @@ Status TransferManager::TransferBufferFromDevice( if (source.size() < size) { return FailedPrecondition( "Source allocation on device not large enough for data tranfer: " - "%lld < %lld", + "%d < %d", source.size(), size); } stream->ThenMemcpy(destination, source, size); @@ -235,7 +267,7 @@ Status TransferManager::TransferBufferToDevice( if (destination->size() < size) { return FailedPrecondition( "Destination allocation on device not large enough for data tranfer: " - "%lld < %lld", + "%d < %d", destination->size(), size); } stream->ThenMemcpy(destination, source, size); @@ -246,9 +278,8 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal) { if (!LayoutUtil::HasLayout(on_host_shape)) { - return InvalidArgument( - "Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); + return InvalidArgument("Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(on_host_shape)); } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 82c599e482d85fc5bbe5a5a48c6c6b053186803b..f952e64af2b675b9c0f8a30e9a2bc3c855e34efa 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -20,12 +20,12 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -57,8 +57,11 @@ class TransferManager { // without waiting for any other operation on a stream to complete. // // This function should be avoided in favor of the asynchronous version below. - virtual StatusOr> TransferLiteralFromDevice( + virtual StatusOr TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); + virtual Status TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal); // Begins transferring a literal containing the data held in the given // ShapedBuffer using the provided executor. @@ -69,9 +72,10 @@ class TransferManager { // // device_buffer is copied by reference and must live at least until done() is // invoked. - virtual void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) = 0; + virtual void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape @@ -101,17 +105,17 @@ class TransferManager { // transfer an array at a known address. Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - void TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source, - std::function>)> done); + void TransferArrayFromDevice(se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, + std::function done); Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - StatusOr> TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source); + StatusOr TransferArrayFromDevice(se::Stream* stream, + const Shape& shape, + const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, // using the given executor. @@ -120,13 +124,13 @@ class TransferManager { // Transfers the given literal from the Outfeed interface of the device, // using the given executor. - virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - const Shape& literal_shape, - Literal* literal) = 0; + virtual Status TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + MutableBorrowingLiteral literal) = 0; // Resets the devices associated with this transfer manager. virtual Status ResetDevices( - tensorflow::gtl::ArraySlice executor) = 0; + absl::Span executor) = 0; // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the @@ -148,6 +152,26 @@ class TransferManager { const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal); + // The given ShapedBuffer holds a handle to allocated memory, but it is not + // in the general case legal to immediately copy or access that allocated + // memory because queued operations on the device may alias that memory. + // Memory ordering is enforced by the Stream's happens-before relationship + // which allows eager deallocation and reallocation of buffers host-side even + // if the device hasn't finished with them. + // + // In certain cases, it can be known that a ShapedBuffer does not have any + // conflicting accesses on the device and thus is eligible to be accessed at + // any time from the host. + // + // This function returns true if device_buffer can be accessed immediately + // without waiting for the Stream's previously enqueued items. This only + // returns true if all subbuffers in device_buffer can be accessed + // immediately. + virtual bool CanShapedBufferBeAccessedNow( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const { + return false; + } + ///// // The TransferManager class also serves as a point to register objects for // the various platforms. @@ -187,8 +211,7 @@ class TransferManager { // to construct a tuple index table in the platform-specific tuple // representation. virtual Status WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) = 0; private: diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 49e1f873192f800056a2272f7d4f698898b0f8a1..7c1f4b5cc67dd2a84271b4f2b8015fdb2ff6e846 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -108,7 +108,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { } std::unique_ptr new_dot = HloInstruction::CreateDot( - dot->shape(), new_lhs, new_rhs, new_dim_numbers); + dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config()); return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } @@ -177,7 +177,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { } auto new_conv = HloInstruction::CreateConvolve( - convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); + convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(), + convolution.window(), new_dnums, convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index 71e8446452f072c22bb730cbda65a1743a95cd4c..f95f982eb89d60884b652cd832dff0363372369c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -23,7 +23,7 @@ namespace xla { // HLO pass that folds transpose operators into Dot operators, where the Dot // operator is implemented by a GEMM kernel that can transpose its inputs. -class TransposeFolding : public HloPassInterface { +class TransposeFolding : public HloModulePass { public: using OperandIndices = std::vector; @@ -49,7 +49,7 @@ class TransposeFolding : public HloPassInterface { explicit TransposeFolding( TransposableGemmOperandsFn transposable_gemm_operands, TransposableConvOperandsFn transposable_conv_operands); - tensorflow::StringPiece name() const override { return "transpose-folding"; } + absl::string_view name() const override { return "transpose-folding"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 58f767e913fbc0023e0c45a4f0e82ecefeeef2d6..79b5c09abb355cd067a4891af558c8c44d80d88e 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -240,10 +240,12 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -293,10 +295,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -351,10 +355,12 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -415,10 +421,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0447807a41b8b32ee297e1ca94393da8c687c5e6..6fed7c76d04ad5d8236fecd07aa27f1eda221ea7 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -26,17 +30,13 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { string BufferAlias::ToString() const { - return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "])"); + return absl::StrCat("BufferAlias(", instruction_->name(), "[", + absl::StrJoin(index_, ","), "])"); } std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { @@ -360,7 +360,7 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { } Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); + absl::Span operands(tuple->operands()); PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); points_to_set.AddPointedToBuffer( logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}), @@ -441,7 +441,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( PerInstruction* pi = PerInst(instruction); CHECK(pi->points_to_set == nullptr) << "instruction should not have been present in the map."; - auto set = MakeUnique(&instruction->shape()); + auto set = absl::make_unique(&instruction->shape()); pi->points_to_set = std::move(set); // Return *set using the iterator returned by emplace. return *pi->points_to_set; @@ -462,21 +462,20 @@ Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { return FailedPrecondition( "LogicalBuffer %s is ill-defined: instruction %s does not define a " "buffer at that index", - buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + buffer.ToString(), buffer.instruction()->name()); } } if (buffer.id() < 0 || buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: invalid id %lld", - buffer.ToString().c_str(), buffer.id()); + return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d", + buffer.ToString(), buffer.id()); } if (GetBuffer(buffer.id()).instruction() != buffer.instruction() || GetBuffer(buffer.id()).index() != buffer.index()) { return FailedPrecondition( "LogicalBuffer %s is ill-defined: buffer with same id differs: %s", - buffer.ToString().c_str(), GetBuffer(buffer.id()).ToString().c_str()); + buffer.ToString(), GetBuffer(buffer.id()).ToString()); } return Status::OK(); @@ -495,8 +494,7 @@ StatusOr TuplePointsToAnalysis::GetBufferDefinedAt( if (buffers.size() != 1 || buffers[0]->instruction() != instruction) { return FailedPrecondition( "instruction %s does not define buffer at index {%s}", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str()); + instruction->name(), absl::StrJoin(index, ",")); } return buffers[0]; } @@ -557,13 +555,12 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( } string TuplePointsToAnalysis::ToString() const { - string output = tensorflow::strings::Printf( - "TuplePointsToSet for module %s:\n", module_->name().c_str()); + string output = + absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name()); for (const auto* computation : module_->MakeNonfusionComputations()) { const char* entry = computation == module_->entry_computation() ? "entry " : ""; - tensorflow::strings::StrAppend(&output, entry, "computation ", - computation->name(), ":\n"); + absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n"); for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { InstructionToString(instruction, &output); @@ -575,12 +572,11 @@ string TuplePointsToAnalysis::ToString() const { } } - tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n"); + absl::StrAppend(&output, "LogicalBuffers:\n"); for (const auto& b : logical_buffer_analysis_->logical_buffers()) { - tensorflow::strings::StrAppend(&output, " buffer ", b->ToString(), ":\n"); + absl::StrAppend(&output, " buffer ", b->ToString(), ":\n"); for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) { - tensorflow::strings::StrAppend(&output, " alias ", alias.ToString(), - "\n"); + absl::StrAppend(&output, " alias ", alias.ToString(), "\n"); } } return output; @@ -589,20 +585,18 @@ string TuplePointsToAnalysis::ToString() const { void TuplePointsToAnalysis::InstructionToString( const HloInstruction* instruction, string* output) const { const string prefix = instruction->IsFused() ? " " : ""; - tensorflow::strings::StrAppend(output, prefix, " instruction ", - instruction->ToShortString(), ":\n"); + absl::StrAppend(output, prefix, " instruction ", + instruction->ToShortString(), ":\n"); const PointsToSet& points_to_set = GetPointsToSet(instruction); points_to_set.ForEachElement([&prefix, &output]( const ShapeIndex& index, const PointsToSet::BufferList& points_to) { - tensorflow::strings::StrAppend( - output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ", - tensorflow::str_util::Join( - points_to, ", ", - [](string* out, const LogicalBuffer* source) { - out->append(source->ToString()); - }), - "\n"); + absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ", + absl::StrJoin(points_to, ", ", + [](string* out, const LogicalBuffer* source) { + out->append(source->ToString()); + }), + "\n"); }); } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 686bb053288fbd6a46ca50a2c65c739354fd2678..a9e8a51e0923362162c6b8a2e97fc334e56d4329 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -33,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/compactptrset.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -109,7 +110,7 @@ class PointsToSet { // Add a tuple source instruction for the given index. void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple); - using BufferList = tensorflow::gtl::InlinedVector; + using BufferList = absl::InlinedVector; // Return the list of logical buffers for the subshape at index. const BufferList& element(const ShapeIndex& index) const { @@ -203,7 +204,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // logical buffer The buffer alias set is the inverse of the points-to set. // That is, LogicalBuffer B is in the points-to set of instruction I at index // N iff instruction I, index N is a BufferAlias of B. - using BufferAliasVector = tensorflow::gtl::InlinedVector; + using BufferAliasVector = absl::InlinedVector; const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const; // Returns the number of logical buffers in the module @@ -226,8 +227,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // instructions produce a single buffer (the top-level buffer), some produce // no buffers (eg bitcast), and some produce more than one buffer (eg, // tuple-shaped parameters). - using BufferDefinitionVector = - tensorflow::gtl::InlinedVector; + using BufferDefinitionVector = absl::InlinedVector; const BufferDefinitionVector& GetBuffersDefinedByInstruction( const HloInstruction* instruction) const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 10d382e8abc92145c1804cbf18bbed714fa34571..e9a07b14ed685fa4388aca583395370a60176cca 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -72,9 +72,8 @@ class TuplePointsToAnalysisTest : public HloTestBase { // Checks that the given points-to set contains exactly (unordered) the given // LogicalBuffers. - void ExpectHasBuffers( - const PointsToSet::BufferList& points_to_set, - tensorflow::gtl::ArraySlice buffers) { + void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set, + absl::Span buffers) { std::vector vec(buffers.begin(), buffers.end()); EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec)); } @@ -83,7 +82,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // top-level buffers of the given instructions. void ExpectHasTopLevelBuffers( const PointsToSet::BufferList& points_to_set, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { PointsToSet::BufferList buffers; for (auto instruction : instructions) { buffers.push_back(GetBuffer(instruction, /*index=*/{})); @@ -94,7 +93,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // Overload which takes a set instead of a vector. void ExpectHasTopLevelBuffers( const PointsToSet::BufferSet& points_to_set, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { ExpectHasTopLevelBuffers( PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()), instructions); @@ -104,8 +103,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // aliases which are exactly (unordered) the given instruction/index pairs. void ExpectHasBufferAliases( const HloInstruction* instruction, const ShapeIndex& index, - tensorflow::gtl::ArraySlice> - expected) { + absl::Span> expected) { const LogicalBuffer* buffer = points_to_analysis_->GetBufferDefinedAt(instruction, index) .ValueOrDie(); @@ -557,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()}))); + Literal elements[] = {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})}; + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); @@ -1066,8 +1064,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index 750950188312c5077d487f2feef0606f07839432..e126a530234c1452bcf91f642f63d4c087935a56 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -25,12 +25,12 @@ namespace xla { // A pass which simplifies patterns of Tuple and GetTupleElement instructions in // the module. -class TupleSimplifier : public HloPassInterface { +class TupleSimplifier : public HloModulePass { public: TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} - tensorflow::StringPiece name() const override { return "tuple-simplifier"; } + absl::string_view name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 39b693872da6bd985d95c2abc9519662c838a3f5..516754e2110ee50a597818c4a8bcfbfbb76c5cec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloTestBase { +class TupleSimplifierTest : public HloVerifiedTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { EXPECT_THAT(computation->root_instruction(), gte); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { EXPECT_THAT(computation->root_instruction(), element); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + Run(module, /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc index 4a530bb0b20582b303f4af969514748b46fd5064..cfb0c787d09557fd1aec3517eb9698cfec323369 100644 --- a/tensorflow/compiler/xla/service/tuple_util.cc +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -40,7 +40,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::AppendSuffix( HloInstruction* input_tuple, - tensorflow::gtl::ArraySlice trailing_values) { + absl::Span trailing_values) { CHECK(ShapeUtil::IsTuple(input_tuple->shape())); HloComputation* computation = input_tuple->parent(); diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h index e5ff9aaa8357fe8e4777d6dee37bbec72e144c06..bc5aac09f270c01515b1f3a704af6949f24cb218 100644 --- a/tensorflow/compiler/xla/service/tuple_util.h +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -38,7 +38,7 @@ class TupleUtil { // `input_tuple`. static HloInstruction* AppendSuffix( HloInstruction* input_tuple, - tensorflow::gtl::ArraySlice trailing_values); + absl::Span trailing_values); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..541b117e0299c94de330604ec5c16e20f07c425f --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -0,0 +1,232 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { + +using absl::nullopt; +using absl::optional; + +// Finds and returns the non-constant operand in instr. +// +// CHECK-fails if instr doesn't have exactly one unique non-constant operand. +static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { + const HloInstruction* result = nullptr; + for (const HloInstruction* operand : instr->operands()) { + if (!operand->IsConstant()) { + if (result != nullptr) { + CHECK_EQ(result, operand); + } + result = operand; + } + } + CHECK_NE(result, nullptr); + return result; +} + +// If all of instr's operands are either constants or have the form +// get-tuple-element(gte_operand, N) +// for the same value N, returns N. Otherwise, returns nullopt. +static optional GetGTEOperandIndex(const HloInstruction* instr, + const HloInstruction* gte_operand) { + VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " + << gte_operand->ToString() << ")"; + optional tuple_idx; + for (const HloInstruction* operand : instr->operands()) { + if (operand->IsConstant()) { + continue; + } + // Look through copies. + // TODO(b/68830972): We wouldn't need this if for loop matching on the GPU + // would run before copy insertion. + if (operand->opcode() == HloOpcode::kCopy) { + operand = operand->operand(0); + } + if (operand->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "instr uses something other than gte(gte_operand): " + << operand->ToString(); + return nullopt; + } + if (operand->operand(0) != gte_operand) { + VLOG(2) << "instr has gte whose operand is not gte_operand: " + << operand->ToString(); + return nullopt; + } + if (tuple_idx && tuple_idx != operand->tuple_index()) { + VLOG(2) << "instr has operands with conflicting gte indices, " + << *tuple_idx << " vs " << operand->tuple_index(); + return nullopt; + } + + tuple_idx = operand->tuple_index(); + } + return tuple_idx; +} + +// Tries to get the tuple index of the induction variable of a while loop. +// +// Checks that the loop condition and root both plumb the induction variable +// through the same tuple index, and that they both apply exactly one op to the +// induction variable before deciding whether to do another loop iteration (in +// the loop condition's case) or packing the induction variable into the result +// tuple (in the loop body's case). +// +// Specifically, checks that the loop condition has structure +// +// root = op(constants, get-tuple-elem(param0, N), constants) +// +// and the loop body has the structure +// +// inc = op(constants, get-tuple-elem(param0, N), constants) +// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). +// +// If so, returns N. Otherwise, returns nullopt. +static optional GetLoopInductionVarTupleIdx( + const HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Finding induction variable for loop " + << while_op->ToShortString(); + + // The while_cond computation should have the form + // + // while_cond_root = + // op(constants, get-tuple-elem(while_cond_param, N), constants). + // + // If it does, set indvar_tuple_idx to N. + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_param = while_cond->parameter_instruction(0); + optional indvar_tuple_idx = + GetGTEOperandIndex(while_cond_root, while_cond_param); + if (!indvar_tuple_idx) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // The while_body computation should have the form + // + // while_body_inc = + // op(constants, get-tuple-elem(while_body_param, N), constants) + // while_body_root = tuple(..., while_body_inc, ...) + // + // where while_body_inc is operand N of while_body_root. + auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); + auto* while_body_param = while_body->parameter_instruction(0); + optional while_body_indvar_tuple_idx = + GetGTEOperandIndex(while_body_inc, while_body_param); + if (!while_body_indvar_tuple_idx) { + VLOG(2) + << "Induction variable not found in while body increment instruction: " + << while_body_inc->ToString(); + return nullopt; + } + if (while_body_indvar_tuple_idx != indvar_tuple_idx) { + VLOG(2) << "Tuple index of induction variable does not match between loop " + "condition (" + << *indvar_tuple_idx << ") and while body (" + << *while_body_indvar_tuple_idx << ")"; + return nullopt; + } + + // Finally, check that the while loop's initial value is a tuple with enough + // elements. + auto* while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); + return nullopt; + } + + VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; + return indvar_tuple_idx; +} + +optional ComputeWhileLoopTripCount(HloInstruction* while_op, + int64 max_value_returned) { + VLOG(2) << "Getting trip count for loop " << while_op->ToString(); + + // The loop's induction variable is found at + // + // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), + // + // where comp is while_op->while_body() or while_op->while_condition(). + optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx) { + return nullopt; + } + + // Now that we know the index of the induction variable, we can we can try to + // compute how many times the loop executes. Start by computing the induction + // variable's initial value. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); + StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init: " + << indvar_init_result.status(); + return nullopt; + } + + auto* while_body = while_op->while_body(); + auto* while_body_indvar_update = + while_body->root_instruction()->operand(*indvar_tuple_idx); + auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + + // The initial value of the induction variable. + Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie(); + for (int64 trip_count = 0; trip_count != max_value_returned + 1; + ++trip_count) { + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + StatusOr result = evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, &indvar_iter_val}}); + if (!result.ok()) { + VLOG(2) << "Couldn't evaluate while cond: " << result.status(); + return nullopt; + } + if (result.ValueOrDie().data() == absl::Span{false}) { + VLOG(2) << "Loop has static trip count of " << trip_count; + return trip_count; + } + + // Calculate the value of the induction variable after one iteration of the + // loop, and check whether the while condition is true with this new value. + StatusOr indvar_next_result = evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}}); + if (!indvar_next_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable update: " + << indvar_next_result.status(); + return nullopt; + } + indvar_iter_val = std::move(indvar_next_result).ValueOrDie(); + } + + VLOG(2) << "Loop has unknown trip count."; + return nullopt; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..bf497f4892b95c927379411468a66d8961465413 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +// Returns the precise trip count of the loop if it's statically known, +// nullopt otherwise. max_value_returned limits the number of steps that are +// evaluated while trying to brute force a loop trip count, trip counts larger +// than max_value_returned result in nullopt. +absl::optional ComputeWhileLoopTripCount(HloInstruction *while_op, + int64 max_value_returned = 128); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 62af45128ad2fb7bf886bef78ec3ab42529a181e..56145822be70f391ac3eaab5fc17db4a80e1b9cc 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -14,10 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { @@ -32,7 +33,7 @@ static Status ReplaceUsesWhileKeepingLoopInvariance( std::vector users; users.reserve(old_instr->user_count()); - c_copy(old_instr->users(), std::back_inserter(users)); + absl::c_copy(old_instr->users(), std::back_inserter(users)); for (auto* user : users) { for (int64 i = 0, e = user->operand_count(); i < e; i++) { @@ -108,10 +109,10 @@ StatusOr WhileLoopConstantSinking::Run(HloModule* module) { // // This will let us sink the constant into the outer while first and then // into the inner while in a single run of this pass. - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 21fb8568a84985692026e145c363500a154a1599..577bad6c7062d2ee40271e407e8eed7655fa13bf 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -50,11 +50,11 @@ namespace xla { // conditions as well. // // TODO(b/79121449): We should also sink broadcasts of constants. -class WhileLoopConstantSinking : public HloPassInterface { +class WhileLoopConstantSinking : public HloModulePass { public: ~WhileLoopConstantSinking() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 266039d2ff8ef4befba0d1023ac1914737207d4f..0e7667de832c54f647d071e3c9563091d0f994aa 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -206,7 +206,8 @@ body { p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 - outfeed = token[] outfeed(p_body.0) + token = token[] after-all() + outfeed = token[] outfeed(p_body.0, token) ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 09ddcffb22c2184262adf87d570870ec000c0e6f..e8fe33e62659ae0fffff1ad46e8ba77f715b76b2 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -14,18 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { +using absl::InlinedVector; using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; -using tensorflow::gtl::InlinedVector; // Copies `to_hoist` to the computation containing `while_instr`, hoisting its // operands as needed. All of its transitive operands are expected to be either @@ -65,8 +66,8 @@ static void CreateLoopInvariantCopy( }; InlinedVector new_operands; - c_transform(old_instruction->operands(), std::back_inserter(new_operands), - get_new_operand); + absl::c_transform(old_instruction->operands(), + std::back_inserter(new_operands), get_new_operand); HloInstruction* new_instruction = parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands( @@ -109,6 +110,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( case HloOpcode::kBitcast: case HloOpcode::kBroadcast: + case HloOpcode::kIota: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: @@ -197,7 +199,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( op->opcode() == HloOpcode::kConstant; }; - if (!c_all_of(instruction->operands(), is_invariant)) { + if (!absl::c_all_of(instruction->operands(), is_invariant)) { continue; } @@ -257,10 +259,10 @@ StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { bool changed = false; std::vector while_instrs; for (auto* comp : module->computations()) { - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8e6cc8787576e4f041229da5cf8dd2b09194eb2a..3031899f71e0fd77f20448d9d7489798af01615c 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that rewrites while loops to hoist loop invariant instructions in // the while body into the computation that contains the while instruction. -class WhileLoopInvariantCodeMotion : public HloPassInterface { +class WhileLoopInvariantCodeMotion : public HloModulePass { public: // If `hoist_constants` is true then constants are always hoisted out of while // loop bodies. Otherwise they are only hoisted out if they enable other @@ -38,7 +38,7 @@ class WhileLoopInvariantCodeMotion : public HloPassInterface { : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index ec05a74e286c89dd8db5ae07580e461938d7c087..9a74f22395099fe4f14cbc9af49814d35203df01 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,34 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; - -// Finds and returns the non-constant operand in instr. -// -// CHECK-fails if instr doesn't have exactly one unique non-constant operand. -static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { - const HloInstruction* result = nullptr; - for (const HloInstruction* operand : instr->operands()) { - if (!operand->IsConstant()) { - if (result != nullptr) { - CHECK_EQ(result, operand); - } - result = operand; - } - } - CHECK_NE(result, nullptr); - return result; -} +using absl::optional; // Determines whether the given instruction is a send/recv node, or has a // subcomputation which contains a send/recv node. @@ -72,211 +54,6 @@ static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { return false; } -// If all of instr's operands are either constants or have the form -// get-tuple-element(gte_operand, N) -// for the same value N, returns N. Otherwise, returns nullopt. -static optional GetGTEOperandIndex(const HloInstruction* instr, - const HloInstruction* gte_operand) { - VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " - << gte_operand->ToString() << ")"; - optional tuple_idx; - for (const HloInstruction* operand : instr->operands()) { - if (operand->IsConstant()) { - continue; - } - if (operand->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "instr uses something other than gte(gte_operand): " - << operand->ToString(); - return nullopt; - } - if (operand->operand(0) != gte_operand) { - VLOG(2) << "instr has gte whose operand is not gte_operand: " - << operand->ToString(); - return nullopt; - } - if (tuple_idx && tuple_idx != operand->tuple_index()) { - VLOG(2) << "instr has operands with conflicting gte indices, " - << *tuple_idx << " vs " << operand->tuple_index(); - return nullopt; - } - - tuple_idx = operand->tuple_index(); - } - return tuple_idx; -} - -// Tries to get the tuple index of the induction variable of a while loop. -// -// Checks that the loop condition and root both plumb the induction variable -// through the same tuple index, and that they both apply exactly one op to the -// induction variable before deciding whether to do another loop iteration (in -// the loop condition's case) or packing the induction variable into the result -// tuple (in the loop body's case). -// -// Specifically, checks that the loop condition has structure -// -// root = op(constants, get-tuple-elem(param0, N), constants) -// -// and the loop body has the structure -// -// inc = op(constants, get-tuple-elem(param0, N), constants) -// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). -// -// If so, returns N. Otherwise, returns nullopt. -static optional GetLoopInductionVarTupleIdx( - const HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Finding induction variable for loop " - << while_op->ToShortString(); - - // The while_cond computation should have the form - // - // while_cond_root = - // op(constants, get-tuple-elem(while_cond_param, N), constants). - // - // If it does, set indvar_tuple_idx to N. - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_param = while_cond->parameter_instruction(0); - optional indvar_tuple_idx = - GetGTEOperandIndex(while_cond_root, while_cond_param); - if (!indvar_tuple_idx) { - VLOG(2) << "Induction variable not found in loop condition: " - << while_cond->root_instruction()->ToString(); - return nullopt; - } - - // The while_body computation should have the form - // - // while_body_inc = - // op(constants, get-tuple-elem(while_body_param, N), constants) - // while_body_root = tuple(..., while_body_inc, ...) - // - // where while_body_inc is operand N of while_body_root. - auto* while_body = while_op->while_body(); - auto* while_body_root = while_body->root_instruction(); - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple instruction: " - << while_body_root->ToString(); - return nullopt; - } - - auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); - auto* while_body_param = while_body->parameter_instruction(0); - optional while_body_indvar_tuple_idx = - GetGTEOperandIndex(while_body_inc, while_body_param); - if (!while_body_indvar_tuple_idx) { - VLOG(2) - << "Induction variable not found in while body increment instruction: " - << while_body_inc->ToString(); - return nullopt; - } - if (while_body_indvar_tuple_idx != indvar_tuple_idx) { - VLOG(2) << "Tuple index of induction variable does not match between loop " - "condition (" - << *indvar_tuple_idx << ") and while body (" - << *while_body_indvar_tuple_idx << ")"; - return nullopt; - } - - // Finally, check that the while loop's initial value is a tuple with enough - // elements. - auto* while_init = while_op->operand(0); - if (while_init->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); - return nullopt; - } - - VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; - return indvar_tuple_idx; -} - -// Tries to determine the number of times the given loop executes. Currently -// simply returns 0, 1, or "can't tell" (nullopt). -static optional GetLoopTripCount(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Getting trip count for loop " << while_op->ToString(); - - // The loop's induction variable is found at - // - // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), - // - // where comp is while_op->while_body() or while_op->while_condition(). - optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); - if (!indvar_tuple_idx) { - return nullopt; - } - - VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx - << " in input tuple."; - - // Now that we know the index of the induction variable, we can we can try to - // compute how many times the loop executes. Start by computing the induction - // variable's initial value. - HloEvaluator evaluator(/*max_loop_iterations=*/0); - auto* while_init = while_op->mutable_operand(0); - auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); - if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init: " - << indvar_init_result.status(); - return nullopt; - } - - // Evaluates the while loop's condition, returning either "true" (continue - // looping), "false" (stop looping), or nullopt (can't evaluate). - auto evaluate_while_cond = [&](const Literal& indvar) -> optional { - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions(while_cond_root, - {{while_cond_indvar, &indvar}}); - if (!result.ok()) { - VLOG(2) << "Couldn't evaluate while cond: " << result.status(); - return nullopt; - } - return result.ValueOrDie()->data() == - tensorflow::gtl::ArraySlice{true}; - }; - - // The initial value of the induction variable. - const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie(); - - // Evaluate whether the while condition is true when seeded with - // indvar_iter0_val. - optional while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val); - if (while_cond_iter0_val == false) { - VLOG(2) << "Loop has static trip count of 0."; - return 0; - } - - // Calculate the value of the induction variable after one iteration of the - // loop, and check whether the while condition is true with this new value. - auto* while_body = while_op->while_body(); - auto* while_body_indvar_update = - while_body->root_instruction()->operand(*indvar_tuple_idx); - auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); - StatusOr> indvar_iter1_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}}); - if (!indvar_iter1_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable update: " - << indvar_iter1_result.status(); - return nullopt; - } - const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie(); - optional while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val); - if (while_cond_iter1_val == false) { - VLOG(2) << "Determined that loop has static trip count of 1."; - return 1; - } - - VLOG(2) << "Loop has unknown trip count >= 1."; - return nullopt; -} - // Tries to remove elements in a while loop's tuple that aren't used within the // loop. // @@ -459,12 +236,11 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { << "Instruction " << user->ToString(print_no_metadata) << " should be unused (except by root of while body), but has " "users: {" - << tensorflow::str_util::Join( - user->users(), ", ", - [&](string* out, const HloInstruction* instr) { - tensorflow::strings::StrAppend( - out, instr->ToString(print_no_metadata)); - }) + << absl::StrJoin(user->users(), ", ", + [&](string* out, const HloInstruction* instr) { + absl::StrAppend( + out, instr->ToString(print_no_metadata)); + }) << "}"; replacements.emplace(user, nullptr); @@ -476,7 +252,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Create the new while condition, body, and init value. std::unique_ptr new_while_cond = while_cond->CloneWithReplacements( - make_while_computation_replacements(while_cond)); + make_while_computation_replacements(while_cond), /*extras=*/{}); std::unordered_map> while_body_replacements = make_while_computation_replacements(while_body); @@ -489,7 +265,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_body_replacements.emplace( while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); std::unique_ptr new_while_body = - while_body->CloneWithReplacements(std::move(while_body_replacements)); + while_body->CloneWithReplacements(std::move(while_body_replacements), + /*extras=*/{}); // Add a new while_init instruction that repackages the old while_init // instruction's elements. We rely on the AlgebraicSimplifier and DCE to @@ -577,7 +354,9 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { } // Remove while loops with static trip count of 0. - optional trip_count = GetLoopTripCount(while_op); + optional trip_count = + ComputeWhileLoopTripCount(while_op, + /*max_value_returned=*/1); if (trip_count && *trip_count == 0) { // The loop never executes, so the value of the loop is the value of its // "init" operand. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 3d3e1d60f294c3a2574513c1c2f071805a341ad1..0bc5a0107bbcfb3b29a01d593fb79b89a863e49b 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -30,12 +30,10 @@ namespace xla { // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. // -class WhileLoopSimplifier : public HloPassInterface { +class WhileLoopSimplifier : public HloModulePass { public: ~WhileLoopSimplifier() override {} - tensorflow::StringPiece name() const override { - return "simplify-while-loops"; - } + absl::string_view name() const override { return "simplify-while-loops"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 2e1571943e537f772ee7dcd95c80ba540445b76e..1c892ba179ec67ccc9dbfe93d925551d6977ba15 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -15,11 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -64,10 +65,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } @@ -103,10 +102,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 1ef17b9d7d2e769aadf39f8a70f78200b88e9d2c..f90ac91f9d07aded8cafccf82dae894c9a149bd1 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -14,15 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrCat; +using absl::StrCat; static StatusOr WidenWhileCondition( HloComputation* narrow_condition, const Shape& wide_shape) { @@ -93,7 +94,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { /*static*/ StatusOr WhileUtil::MakeInstructionsLiveIn( HloInstruction* while_instr, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { CHECK(ShapeUtil::IsTuple(while_instr->shape())); int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); @@ -206,7 +207,7 @@ static StatusOr MakeInitTupleFromInitValues( HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); init_values_with_indvar.push_back(zero); - c_copy(init_values, std::back_inserter(init_values_with_indvar)); + absl::c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( HloInstruction::CreateTuple(init_values_with_indvar)); } @@ -215,8 +216,9 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { std::vector loop_state_shape_components; loop_state_shape_components.reserve(init_values.size() + 1); loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); - c_transform(init_values, std::back_inserter(loop_state_shape_components), - [](HloInstruction* instr) { return instr->shape(); }); + absl::c_transform(init_values, + std::back_inserter(loop_state_shape_components), + [](HloInstruction* instr) { return instr->shape(); }); return ShapeUtil::MakeTupleShape(loop_state_shape_components); } diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index e67636d80f4b682fe1335eae535fb86105ac082b..b1c4486887ae0ddbe2ba4e79f45a265689111017 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -55,7 +55,7 @@ class WhileUtil { // that contains `while_instr`. static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, - tensorflow::gtl::ArraySlice instructions); + absl::Span instructions); using LoopStateTy = std::vector; using LoopBodyGeneratorTy = std::function( diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 2ccb919acf9c4e7c59a1ebaf36f42a6781068b5e..5e6941933330fde29bc9c779aae4bb3c36914660 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" @@ -206,7 +207,7 @@ ENTRY main { auto is_while = [](const HloInstruction* instr) { return instr->opcode() == HloOpcode::kWhile; }; - EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); + EXPECT_EQ(absl::c_count_if(main->instructions(), is_while), 1); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 8763e588c484011ba2ccbc7cad8f29817347a605..87294120d51d244d9f2649cf95916f022bf829cb 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -21,10 +21,10 @@ limitations under the License. // HLO pass that replaces zero sized Hlos with a zero sized constant literal. namespace xla { -class ZeroSizedHloElimination : public HloPassInterface { +class ZeroSizedHloElimination : public HloModulePass { public: StatusOr Run(HloModule* module) override; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "zero_sized_hlo_elimination"; } }; diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index caad31d6ce7ce35fa362ec364b0d7f1d95973715..d44db89d571891ecef554cd45c050017833982bb 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -25,8 +25,8 @@ namespace xla { Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(other_shape).c_str(), - ShapeUtil::HumanString(shape()).c_str()); + ShapeUtil::HumanString(other_shape), + ShapeUtil::HumanString(shape())); } shape_ = other_shape; return Status::OK(); @@ -35,8 +35,8 @@ Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(*to_shape).c_str(), - ShapeUtil::HumanString(shape()).c_str()); + ShapeUtil::HumanString(*to_shape), + ShapeUtil::HumanString(shape())); } *to_shape = shape_; return Status::OK(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index c74dd648addd70633edc2ec10a60879a00942716..df610102b4c7fa08c0b7030124939009130f89f4 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -21,16 +21,16 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -224,14 +224,13 @@ class ShapeTree { // REQUIRES: index must exist in the ShapeTree. iterator find(ShapeIndexView index) { Node* element = Lookup(index); - return iterator(&nodes_, typename std::vector::iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.begin() + (element - &nodes_[0]); + return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); - return iterator(&nodes_, - typename std::vector::const_iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.cbegin() + (element - &nodes_[0]); + return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } // Returns the number of leaf nodes in the tree. @@ -262,6 +261,25 @@ class ShapeTree { template Status ForEachMutableElementWithStatus(const Fn& func); + // Maps each element to generate a new tree with the same shape. + template + ShapeTree Map(const std::function& func) { + ShapeTree result(shape_storage_); + ForEachElement([&](const ShapeIndex& index, const T& t) { + *result.mutable_element(index) = func(t); + }); + return result; + } + + template + ShapeTree Map(const std::function& func) { + ShapeTree result(shape_storage_); + ForEachMutableElement([&](const ShapeIndex& index, T* t) { + *result.mutable_element(index) = func(t); + }); + return result; + } + // Copy the subtree of values from 'other' rooted at ShapeIndex // 'source_base_index' into the subtree of value in this ShapeTree rooted at // 'target_base_index'. @@ -463,9 +481,6 @@ template ShapeTree::ShapeTree(Shape shape) : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { - // The shape_ field is just used to hold the structure of the shape. - // It should not be relied upon to store layout information. - LayoutUtil::ClearLayout(shape_storage_.get()); const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); @@ -502,9 +517,6 @@ template ShapeTree::ShapeTree(Shape shape, const T& init_value) : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { - // The shape_ field is just used to hold the structure of the shape. - // It should not be relied upon to store layout information. - LayoutUtil::ClearLayout(shape_storage_.get()); const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index c4c958be4a18f23b8e34f9e619e447c6bf4334b5..c8ff55e7845785d9292516b823fb591cc28cbfad 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_tree.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -242,7 +243,7 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { ShapeTree> shape_tree{tuple_shape_}; EXPECT_EQ(shape_tree.element({2}).get(), nullptr); - *shape_tree.mutable_element({2}) = MakeUnique(42); + *shape_tree.mutable_element({2}) = absl::make_unique(42); EXPECT_EQ(*shape_tree.element({2}), 42); } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 34869cc5078699603c006387161fddd4fee4a9f8..020c167ee953bbb3508bae94107de60f386602c0 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,14 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/overflow_util.h" @@ -30,26 +38,22 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" namespace xla { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); } string ShapeIndexView::ToString() const { - return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", absl::StrJoin(indices_, ","), "}"); } bool ShapeIndexView::operator==(const ShapeIndexView& other) const { @@ -91,11 +95,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, } if (ShapeUtil::IsTuple(lhs)) { - return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts, - ignore_fp_precision); - }); + return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), + [=](const Shape& l, const Shape& r) { + return CompareShapes(l, r, compare_layouts, + ignore_fp_precision); + }); } else if (!ShapeUtil::IsArray(lhs)) { // Non-tuple, non-array tupes such as opaque and token types are trivially // the same. @@ -107,13 +111,13 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, return false; } if (LayoutUtil::IsDenseArray(lhs)) { - if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), - LayoutUtil::MinorToMajor(rhs))) { + if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), + LayoutUtil::MinorToMajor(rhs))) { VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; } - if (!ContainersEqual(lhs.layout().padded_dimensions(), - rhs.layout().padded_dimensions())) { + if (!absl::c_equal(lhs.layout().padded_dimensions(), + rhs.layout().padded_dimensions())) { VLOG(3) << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; return false; @@ -135,15 +139,15 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, // Constructs and returns the new shape with the given minor_to_major order in // its Layout. StatusOr MakeShapeWithLayoutInternal( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major) { if (dimensions.size() != minor_to_major.size()) { return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", dimensions.size(), minor_to_major.size()); } if (element_type == OPAQUE || element_type == TUPLE) { return InvalidArgument("Unsupported element type: %s", - PrimitiveType_Name(element_type).c_str()); + PrimitiveType_Name(element_type)); } Shape shape = ShapeUtil::MakeShape(element_type, dimensions); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); @@ -210,8 +214,8 @@ StatusOr MakeShapeWithLayoutInternal( return program_shape; } -/* static */ Shape ShapeUtil::MakeShape( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { +/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, + absl::Span dimensions) { CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, &result); @@ -219,21 +223,21 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ Shape ShapeUtil::MakeShapeWithLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major) { return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) .ValueOrDie(); } /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { + PrimitiveType element_type, absl::Span dimensions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); return MakeShapeWithLayout(element_type, dimensions, layout); } /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + PrimitiveType element_type, absl::Span dimensions, int64 max_sparse_elements) { CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); @@ -252,9 +256,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return MakeShapeWithDescendingLayout(shape.element_type(), dims); } -/* static */ void ShapeUtil::PopulateShape( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - Shape* shape) { +/* static */ void ShapeUtil::PopulateShape(PrimitiveType element_type, + absl::Span dimensions, + Shape* shape) { shape->Clear(); shape->set_element_type(element_type); for (int64 dimension : dimensions) { @@ -264,8 +268,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( TF_DCHECK_OK(ValidateShape(*shape)); } -/* static */ Shape ShapeUtil::MakeTupleShape( - tensorflow::gtl::ArraySlice shapes) { +/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); result.mutable_tuple_shapes()->Reserve(shapes.size()); @@ -419,8 +422,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); - CHECK_EQ(shape.dimensions_size(), Rank(shape)); + DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); + DCHECK_EQ(shape.dimensions_size(), Rank(shape)); + if (shape.dimensions().size() == 1) { + return shape.dimensions()[0]; + } return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, std::multiplies()); @@ -438,6 +444,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return count; } +/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type) { + if (shape.element_type() == primitive_type) { + return true; + } + for (const Shape& element_shape : shape.tuple_shapes()) { + if (HasPrimitiveType(element_shape, primitive_type)) { + return true; + } + } + return false; +} + /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } @@ -449,14 +468,14 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( namespace { // Class to memoize the computation of -// tensorflow::str_util::Lowercase(PrimitiveType_Name(p)) +// absl::AsciiStrToLower(PrimitiveType_Name(p)) // for all PrimitiveType values "p" class PrimitiveTypeNameGenerator { public: PrimitiveTypeNameGenerator() { for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = tensorflow::str_util::Lowercase( + lowercase_name_[i] = absl::AsciiStrToLower( PrimitiveType_Name(static_cast(i))); } } @@ -487,8 +506,7 @@ StatusOr StringToPrimitiveType(const string& name) { }(); auto found = name_to_type->find(name); if (found == name_to_type->end()) { - return InvalidArgument("Invalid element type string: \"%s\".", - name.c_str()); + return InvalidArgument("Invalid element type string: \"%s\".", name); } return found->second; } @@ -507,7 +525,7 @@ StatusOr StringToPrimitiveType(const string& name) { return text; } return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", - tensorflow::str_util::Join(shape.dimensions(), ","), "]"); + absl::StrJoin(shape.dimensions(), ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -543,30 +561,29 @@ StatusOr StringToPrimitiveType(const string& name) { : "(unknown)", ": ", HumanString(shape))); } - return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ", HumanString(program_shape.result())); } namespace { // Parses shapes with simple recursive descent structure -- consumes from the // front of s and passes that view recursively as required. -StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { - tensorflow::str_util::RemoveLeadingWhitespace(s); +StatusOr ParseShapeStringInternal(absl::string_view* s) { + *s = StripLeadingAsciiWhitespace(*s); - if (tensorflow::str_util::ConsumePrefix(s, "(")) { // Tuple. + if (absl::ConsumePrefix(s, "(")) { // Tuple. std::vector shapes; bool must_end = false; while (true) { - if (tensorflow::str_util::ConsumePrefix(s, ")")) { + if (absl::ConsumePrefix(s, ")")) { break; } else if (must_end) { - return InvalidArgument("Expected end of tuple; got: \"%s\"", - std::string(*s).c_str()); + return InvalidArgument("Expected end of tuple; got: \"%s\"", *s); } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - tensorflow::str_util::RemoveLeadingWhitespace(s); - must_end = !tensorflow::str_util::ConsumePrefix(s, ","); + *s = StripLeadingAsciiWhitespace(*s); + must_end = !absl::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); } @@ -575,9 +592,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { string dimensions_string; string format_string; string layout_string; - // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so + // absl::string_view is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding - // amount from our StringPiece type. + // amount from our string_view type. static LazyRE2 shape_pattern = { "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); @@ -585,12 +602,12 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { &dimensions_string, &format_string, &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); - auto string_to_int64 = [&s](const string& input) -> StatusOr { + auto string_to_int64 = [&s](absl::string_view input) -> StatusOr { int64 element; - if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { + if (!absl::SimpleAtoi(input, &element)) { return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - input.c_str(), std::string(*s).c_str()); + "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input, + *s); } return element; }; @@ -598,7 +615,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { auto comma_list_to_int64s = [string_to_int64](const string& input) -> StatusOr> { std::vector results; - for (const string& piece : tensorflow::str_util::Split(input, ',')) { + for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) { TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); results.push_back(element); } @@ -614,7 +631,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { StringToPrimitiveType(element_type_string)); if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", - element_type_string.c_str()); + element_type_string); } Shape result; @@ -644,17 +661,14 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return std::move(result); } - return InvalidArgument("Invalid shape string to parse: \"%s\"", - std::string(*s).c_str()); + return InvalidArgument("Invalid shape string to parse: \"%s\"", *s); } } // namespace -/* static */ StatusOr ShapeUtil::ParseShapeString( - tensorflow::StringPiece s) { +/* static */ StatusOr ShapeUtil::ParseShapeString(absl::string_view s) { TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); if (!s.empty()) { - return InvalidArgument("Invalid shape string to parse: \"%s\"", - std::string(s).c_str()); + return InvalidArgument("Invalid shape string to parse: \"%s\"", s); } return shape; } @@ -663,7 +677,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { const Shape& rhs) { CHECK(ShapeUtil::IsArray(lhs)); CHECK(ShapeUtil::IsArray(rhs)); - return ContainersEqual(lhs.dimensions(), rhs.dimensions()); + return absl::c_equal(lhs.dimensions(), rhs.dimensions()); } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { @@ -677,8 +691,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return IsArray(rhs) && SameDimensions(lhs, rhs); } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringElementType); + absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringElementType); } else { // Opaque, token, etc types are vacuously compatible. return lhs.element_type() == rhs.element_type(); @@ -692,8 +706,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { CompatibleIgnoringElementType(lhs, rhs); } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringFpPrecision); + absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringFpPrecision); } else { // Opaque, token, etc types are vacuously compatible. return lhs.element_type() == rhs.element_type(); @@ -792,7 +806,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); - tensorflow::gtl::ArraySlice padded_dimensions = + absl::Span padded_dimensions = LayoutUtil::PaddedDimensions(shape); if (!padded_dimensions.empty()) { CHECK_EQ(Rank(shape), padded_dimensions.size()); @@ -819,7 +833,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { const Shape& shape) { if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + shape.ShortDebugString()); } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { @@ -842,21 +856,21 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", - LowercasePrimitiveTypeName(shape.element_type()).c_str(), - shape.ShortDebugString().c_str()); + LowercasePrimitiveTypeName(shape.element_type()), + shape.ShortDebugString()); } if (shape.has_layout()) { return InvalidArgument( "shape has %s element type, but has layout field: %s", - LowercasePrimitiveTypeName(shape.element_type()).c_str(), - shape.ShortDebugString().c_str()); + LowercasePrimitiveTypeName(shape.element_type()), + shape.ShortDebugString()); } return Status::OK(); } if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( - "shape's rank is mismatched with dimension count; rank=%lld " + "shape's rank is mismatched with dimension count; rank=%d " "dimensions_size=%d", Rank(shape), shape.dimensions_size()); } @@ -864,9 +878,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( - "shape's dimensions must not be < 0; dimension at index %lld was " - "%lld", - i, dimension); + "shape's dimensions must not be < 0; dimension at index %d was %d", i, + dimension); } } @@ -931,7 +944,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape_size < 0) { return InvalidArgument("Shape %s size may overflow int64.", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } VLOG(3) << "Shape size is valid: " << shape_size; @@ -991,7 +1004,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { i >= return_shape->tuple_shapes_size()) { return InvalidArgument( "Shape index %s not a valid subshape index for tuple with shape %s", - index.ToString().c_str(), shape.DebugString().c_str()); + index.ToString(), shape.DebugString()); } return_shape = &return_shape->tuple_shapes(i); } @@ -1014,12 +1027,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + if (!IsTuple(shape)) { + return 1; + } int64 count = 0; - ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { - if (IsLeafIndex(shape, index)) { - ++count; - } - }); + for (const Shape& subshape : shape.tuple_shapes()) { + count += GetLeafCount(subshape); + } return count; } @@ -1036,7 +1050,7 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { CHECK(ShapeUtil::IsArray(shape)); - return ArrayContains(AsInt64Slice(shape.dimensions()), 1); + return absl::c_linear_search(shape.dimensions(), 1); } namespace { @@ -1116,7 +1130,7 @@ Status ForEachMutableSubshapeHelper( } /* static */ Shape ShapeUtil::PermuteDimensions( - tensorflow::gtl::ArraySlice permutation, const Shape& shape) { + absl::Span permutation, const Shape& shape) { Shape new_shape = shape; new_shape.clear_dimensions(); for (auto dim : Permute(permutation, shape.dimensions())) { @@ -1171,8 +1185,7 @@ Status ForEachMutableSubshapeHelper( CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation))) << "shape=" << HumanStringWithLayout(shape) << ", new_shape=" << HumanStringWithLayout(new_shape) - << ", permutation={" << tensorflow::str_util::Join(permutation, ",") - << "}"; + << ", permutation={" << absl::StrJoin(permutation, ",") << "}"; } return new_shape; } @@ -1261,7 +1274,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, - tensorflow::gtl::ArraySlice dimension_mapping) { + absl::Span dimension_mapping) { CHECK(LayoutUtil::HasLayout(input_shape) && LayoutUtil::HasLayout(output_shape)); @@ -1288,7 +1301,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, // apply(input_dimensions, I) = // apply((dimension_mapping * output_dimensions), I) // input_dimensions = dimension_mapping * output_dimensions - return ContainersEqual( + return absl::c_equal( ComposePermutations(dimension_mapping, AsInt64Slice(output_shape.layout().minor_to_major())), input_shape.layout().minor_to_major()); @@ -1459,7 +1472,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, check_input_unit_indices(output_shape, input_shape); } -/* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( +/* static */ absl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { CHECK(IsArray(input_shape)); CHECK(IsArray(output_shape)); @@ -1498,7 +1511,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (input_dimension_product < output_dimension_product || j == output_rank) { if (i == input_rank) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } dimension_to_alignment_index[i] = alignment.size() - 1; input_dimension_product *= input_shape.dimensions(i); @@ -1509,7 +1522,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } } if (input_dimension_product != output_dimension_product) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } // We also need to store an end element so that we know where the last // alignment part ends. @@ -1553,7 +1566,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; ++i, ++j) { if (i == input_rank) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } // Skip trivial dimensions with a bound of 1. if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { @@ -1566,7 +1579,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (dimension_to_alignment_index[input_dimension_numbers[i]] != current_alignment_index || input_dimension_numbers[i] > current_dimension_number) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } current_dimension_number = input_dimension_numbers[i]; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index d6f17fc965d24bbbbd083b8dd0ec11a59e49ed4e..d8bb27beae64bb665c79c2cd7134f613495529cc 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -22,6 +22,10 @@ limitations under the License. #include #include +#include "absl/base/macros.h" +#include "absl/container/inlined_vector.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -30,9 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -74,7 +75,7 @@ class ShapeIndex { // push_front is O(n^2), but shapes don't usually have a ton of dimensions. void push_front(int64 value) { indices_.insert(indices_.begin(), value); } - using container_type = tensorflow::gtl::InlinedVector; + using container_type = absl::InlinedVector; container_type::const_iterator begin() const { return indices_.begin(); } container_type::const_iterator end() const { return indices_.end(); } @@ -131,12 +132,12 @@ class ShapeIndexView { } ShapeIndexView ConsumeFront() const { ShapeIndexView result = *this; - result.indices_.pop_front(); + result.indices_.remove_prefix(1); return result; } ShapeIndexView ConsumeBack() const { ShapeIndexView result = *this; - result.indices_.pop_back(); + result.indices_.remove_suffix(1); return result; } ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); } @@ -147,7 +148,7 @@ class ShapeIndexView { string ToString() const; private: - tensorflow::gtl::ArraySlice indices_; + absl::Span indices_; }; std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); @@ -180,6 +181,10 @@ class ShapeUtil { // As ElementsIn(), but recurses through tuples. static int64 ElementsInRecursive(const Shape& shape); + // Returns true if shape has the primitive type, recurses through tuples. + static bool HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type); + // Returns true if 'shape' is an array with zero elements. static bool IsZeroElementArray(const Shape& shape); @@ -228,7 +233,7 @@ class ShapeUtil { // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. - static StatusOr ParseShapeString(tensorflow::StringPiece s); + static StatusOr ParseShapeString(absl::string_view s); // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. @@ -328,7 +333,7 @@ class ShapeUtil { static Shape ChangeElementType(const Shape& original, PrimitiveType type); // Creates a tuple shape from a slice of element shapes within the tuple. - static Shape MakeTupleShape(tensorflow::gtl::ArraySlice shapes); + static Shape MakeTupleShape(absl::Span shapes); // Creates an opaque shape. These are generally used for threading a context // into a custom operation. @@ -355,31 +360,29 @@ class ShapeUtil { // Constructs a new shape with the given element type and sequence of // dimensions. static Shape MakeShape(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a Shape with element type corresponding to T and the given // dimensions template - static Shape MakeShapeWithType( - tensorflow::gtl::ArraySlice dimensions) { + static Shape MakeShapeWithType(absl::Span dimensions) { return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dimensions); } // Constructs a new shape with the given minor_to_major order in its Layout. // Returns a value shape such that shape.has_layout(). - static Shape MakeShapeWithLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major); + static Shape MakeShapeWithLayout(PrimitiveType element_type, + absl::Span dimensions, + absl::Span minor_to_major); - static Shape MakeShapeWithSparseLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - int64 max_sparse_elements); + static Shape MakeShapeWithSparseLayout(PrimitiveType element_type, + absl::Span dimensions, + int64 max_sparse_elements); // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). static Shape MakeShapeWithDescendingLayout( - PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions); + PrimitiveType element_type, absl::Span dimensions); // Returns a new Shape based on the given Shape with low-dimension-major // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions @@ -391,8 +394,7 @@ class ShapeUtil { // As MakeShape, but the object to write to is passed in. static void PopulateShape(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, - Shape* shape); + absl::Span dimensions, Shape* shape); // Validates that the provided shape satisfies invariants. static Status ValidateShape(const Shape& shape); @@ -478,8 +480,7 @@ class ShapeUtil { // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. - // - // DEPRECATED: Use Equal() instead. + ABSL_DEPRECATED("Use Equal() instead.") static bool ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions); @@ -539,7 +540,7 @@ class ShapeUtil { // !HasLayout(shape) || // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape), // InversePermutation(permutation)). - static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, + static Shape PermuteDimensions(absl::Span permutation, const Shape& shape); // If we can go from `shape_pre` to `shape_post` by merely inserting or @@ -580,9 +581,9 @@ class ShapeUtil { // to its input and thus may be replaced with a bitcast. // // Precondition: Both input_shape and output_shape have explicit layouts. - static bool TransposeIsBitcast( - const Shape& input_shape, const Shape& output_shape, - tensorflow::gtl::ArraySlice dimension_mapping); + static bool TransposeIsBitcast(const Shape& input_shape, + const Shape& output_shape, + absl::Span dimension_mapping); // Returns whether a reshape from "input_shape" to "output_shape" is a // bitcast. @@ -597,8 +598,8 @@ class ShapeUtil { // layout). The layout of 'input_shape' is kept fixed. Returns // 'output_shape_with_layout' if such a layout can be found, and an error // otherwise. - static tensorflow::gtl::optional AlignLayouts( - const Shape& input_shape, const Shape& output_shape); + static absl::optional AlignLayouts(const Shape& input_shape, + const Shape& output_shape); // Returns a shape with the given dimension deleted. // For example: @@ -621,12 +622,12 @@ class ShapeUtil { // continue, or false otherwise. // // visitor_function must be a callable of type - // StatusOr(ArraySlice) or compatible. + // StatusOr(Span) or compatible. template static Status ForEachIndexWithStatus(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { return ForEachIndexInternal(shape, base, count, incr, visitor_function); } @@ -648,13 +649,12 @@ class ShapeUtil { } template - static void ForEachIndex(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + static void ForEachIndex(const Shape& shape, absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { ForEachIndexWithStatus(shape, base, count, incr, - [&](tensorflow::gtl::ArraySlice indices) { + [&](absl::Span indices) { return StatusOr(visitor_function(indices)); }) .IgnoreError(); @@ -676,7 +676,7 @@ class ShapeUtil { template static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { ForEachIndexWithStatus(shape, - [&](tensorflow::gtl::ArraySlice indices) { + [&](absl::Span indices) { return StatusOr(visitor_function(indices)); }) .IgnoreError(); @@ -687,18 +687,18 @@ class ShapeUtil { // matter. // // visitor_function must be a callable of type - // void(ArraySlice) or compatible. + // void(Span) or compatible. template static void ForEachIndexParallel(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { // The parallel version of ForEachIndexInternal can never fail. CHECK(ForEachIndexInternal( shape, base, count, incr, - [&visitor_function](tensorflow::gtl::ArraySlice indexes) - -> StatusOr { + [&visitor_function]( + absl::Span indexes) -> StatusOr { visitor_function(indexes); return true; }, @@ -720,9 +720,9 @@ class ShapeUtil { template static Status ForEachIndexInternal(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function, bool parallel = false) { if (ShapeUtil::IsZeroElementArray(shape)) { @@ -737,13 +737,13 @@ class ShapeUtil { int64 n = -1; std::vector indexes(base.begin(), base.end()); const int kNumThreads = tensorflow::port::NumSchedulableCPUs(); - tensorflow::gtl::optional pool; + absl::optional pool; if (parallel) { pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads); } while (n < rank) { - if (pool != tensorflow::gtl::nullopt) { + if (pool != absl::nullopt) { pool->Schedule( [indexes, &visitor_function] { visitor_function(indexes); }); } else { diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index e5dd62ae9a3dd9b961a7ae03a99c19220dbd43e7..c622ecdca1fd66604d1a6ceaf705f2e70edaee55 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" @@ -23,8 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { @@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) { EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); } +TEST(ShapeUtilTest, HasPrimitiveType) { + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}), + S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}), + S16)); +} + TEST(ShapeUtilTest, IsZeroElementArray) { EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {}))); EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0}))); @@ -705,11 +721,10 @@ TEST(ShapeUtilTest, ForEachIndex) { Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); // Increments at every invocation. int invocations = 0; - auto increment_func = - [&invocations](tensorflow::gtl::ArraySlice indexes) { - invocations++; - return true; - }; + auto increment_func = [&invocations](absl::Span indexes) { + invocations++; + return true; + }; std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); @@ -726,8 +741,7 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) { // Increments at every invocation. int invocations = 0; auto increment_func = - [&invocations]( - tensorflow::gtl::ArraySlice indexes) -> StatusOr { + [&invocations](absl::Span indexes) -> StatusOr { if (++invocations == 5) { return Unimplemented("Cannot increment beyond 5."); } @@ -748,7 +762,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel) { Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); int64 output[10][10]; int init = 5; - auto set_func = [&](tensorflow::gtl::ArraySlice indexes) { + auto set_func = [&](absl::Span indexes) { output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; }; @@ -849,13 +863,13 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) { std::iota(layout.begin(), layout.end(), 0); do { Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout); - SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s))); + SCOPED_TRACE(absl::StrCat("s=", ShapeUtil::HumanString(s))); std::vector permutation(3); std::iota(permutation.begin(), permutation.end(), 0); do { - SCOPED_TRACE(tensorflow::strings::StrCat( - "permutation=", tensorflow::str_util::Join(permutation, ","))); + SCOPED_TRACE( + absl::StrCat("permutation=", absl::StrJoin(permutation, ","))); // TransposeIsBitcast takes the inverse of the permutation that // PermuteDimensions takes. diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index 31844abd89a020c87c403353374a80fb639a3244..1c135dda864b3060b8bdc6369f18268d7c5c7f9e 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -33,7 +33,7 @@ SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, } SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, - tensorflow::gtl::ArraySlice indices) + absl::Span indices) : SparseIndexArray(max_indices, rank, std::vector(indices.begin(), indices.end())) {} @@ -48,25 +48,24 @@ int64 SparseIndexArray::index_count() const { return indices_.size() / rank_; } -tensorflow::gtl::ArraySlice SparseIndexArray::At( +absl::Span SparseIndexArray::At( int64 sparse_element_number) const { CHECK_GT(rank_, 0); CHECK_GE(sparse_element_number, 0); CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); - return tensorflow::gtl::ArraySlice( + return absl::Span( indices_.data() + rank_ * sparse_element_number, rank_); } -tensorflow::gtl::MutableArraySlice SparseIndexArray::At( - int64 sparse_element_number) { +absl::Span SparseIndexArray::At(int64 sparse_element_number) { CHECK_GT(rank_, 0); CHECK_GE(sparse_element_number, 0); CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); - return tensorflow::gtl::MutableArraySlice( - indices_.data() + rank_ * sparse_element_number, rank_); + return absl::Span(indices_.data() + rank_ * sparse_element_number, + rank_); } -void SparseIndexArray::Append(tensorflow::gtl::ArraySlice index) { +void SparseIndexArray::Append(absl::Span index) { CHECK_GT(rank_, 0); CHECK_EQ(index.size(), rank_); indices_.insert(indices_.end(), index.begin(), index.end()); @@ -90,12 +89,12 @@ bool SparseIndexArray::Validate(const Shape& shape) const { if (num_indices < 2) { return true; } - tensorflow::gtl::ArraySlice last = At(0); + absl::Span last = At(0); if (!IndexUtil::IndexInBounds(shape, last)) { return false; } for (int64 n = 1; n < num_indices; ++n) { - tensorflow::gtl::ArraySlice next = At(n); + absl::Span next = At(n); if (!IndexUtil::IndexInBounds(shape, next)) { return false; } diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index f2ce22d6721ff8da46f741ccedc2a63dea5994c8..a96d483462efd77ae4761541e8c79b2c84fa49f3 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -20,10 +20,11 @@ limitations under the License. #include +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -64,7 +65,7 @@ class SparseIndexArray { SparseIndexArray(int64 max_indices, int64 rank, std::vector indices = {}); SparseIndexArray(int64 max_indices, int64 rank, - tensorflow::gtl::ArraySlice indices); + absl::Span indices); // Returns the number of elements represented by the indices stored in the // array. @@ -72,12 +73,12 @@ class SparseIndexArray { // Returns a slice that refers to the given sparse index number. The argument // must be in the range [0, element_count()). - tensorflow::gtl::ArraySlice At(int64 sparse_element_number) const; - tensorflow::gtl::MutableArraySlice At(int64 sparse_element_number); + absl::Span At(int64 sparse_element_number) const; + absl::Span At(int64 sparse_element_number); // Adds the given index at the end of the array. The new size of the // SparseIndexArray must not exceed `max_indices`. - void Append(tensorflow::gtl::ArraySlice index); + void Append(absl::Span index); // Removes all indices from the array. void Clear(); @@ -95,8 +96,8 @@ class SparseIndexArray { int64 max_indices() const { return max_indices_; } // Returns a pointer to the int64 array that holds the sparse indices. - tensorflow::gtl::MutableArraySlice mutable_data() { return &indices_; } - tensorflow::gtl::ArraySlice data() const { return indices_; } + absl::Span mutable_data() { return absl::MakeSpan(indices_); } + absl::Span data() const { return indices_; } // Sorts this sparse index array along with the set of corresponding values. // The indices and values are sorted in the lexicographic order of the @@ -114,7 +115,7 @@ class SparseIndexArray { // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl; // template - void SortWithValues(tensorflow::gtl::MutableArraySlice values); + void SortWithValues(absl::Span values); private: std::vector indices_; @@ -123,8 +124,7 @@ class SparseIndexArray { }; template -void SparseIndexArray::SortWithValues( - tensorflow::gtl::MutableArraySlice values) { +void SparseIndexArray::SortWithValues(absl::Span values) { int64 num_elements = index_count(); CHECK_EQ(values.size(), num_elements); std::vector sort_order; @@ -139,7 +139,7 @@ void SparseIndexArray::SortWithValues( // Reorder the array elements according to sort_order. Work through the array // and follow cycles so we can do the reorder in-place. - tensorflow::gtl::InlinedVector saved_index(rank()); + absl::InlinedVector saved_index(rank()); for (int64 i = 0; i < num_elements; ++i) { // sort_order[i] == -1 indicates the element has already been copied. if (sort_order[i] < 0) { diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc index 7377f88958dcb7daf3d3f4f0e07966fdc9294580..e54057c4007078c76b79fe44d5706665e266c083 100644 --- a/tensorflow/compiler/xla/sparse_index_array_test.cc +++ b/tensorflow/compiler/xla/sparse_index_array_test.cc @@ -33,7 +33,7 @@ TEST(SparseIndexArrayTest, Sort) { std::vector values = { 12.0, 13.0, 11.0, 15.0, 14.0, 16.0, }; - a.SortWithValues(&values); + a.SortWithValues(absl::MakeSpan(values)); ASSERT_EQ(a.data(), std::vector({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 6, 7, 8})); ASSERT_EQ(values, std::vector({11.0, 12.0, 13.0, 14.0, 15.0, 16.0})); diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc index a6b1f9004f096abb3b01d315938b0a23bea1ca48..b88fe367d7416a26c1147fd5e10fb20772814fe5 100644 --- a/tensorflow/compiler/xla/status_macros.cc +++ b/tensorflow/compiler/xla/status_macros.cc @@ -17,9 +17,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stacktrace.h" @@ -37,8 +36,7 @@ static void LogError(const Status& status, const char* filename, int line, if (TF_PREDICT_TRUE(log_severity != tensorflow::NUM_SEVERITIES)) { string stack_trace; if (should_log_stack_trace) { - stack_trace = - tensorflow::strings::StrCat("\n", tensorflow::CurrentStackTrace()); + stack_trace = absl::StrCat("\n", tensorflow::CurrentStackTrace()); } switch (log_severity) { case tensorflow::INFO: @@ -142,17 +140,15 @@ Status MakeErrorStream::Impl::GetStatus() { is_done_ = true; const string& stream_str = stream_.str(); - const string str = - prior_message_handling_ == kAppendToPriorMessage - ? tensorflow::strings::StrCat(prior_message_, stream_str) - : tensorflow::strings::StrCat(stream_str, prior_message_); + const string str = prior_message_handling_ == kAppendToPriorMessage + ? absl::StrCat(prior_message_, stream_str) + : absl::StrCat(stream_str, prior_message_); if (TF_PREDICT_FALSE(str.empty())) { - return MakeError(file_, line_, code_, - tensorflow::strings::StrCat( - str, "Error without message at ", file_, ":", line_), - true /* should_log */, - tensorflow::ERROR /* log_severity */, - should_log_stack_trace_); + return MakeError( + file_, line_, code_, + absl::StrCat(str, "Error without message at ", file_, ":", line_), + true /* should_log */, tensorflow::ERROR /* log_severity */, + should_log_stack_trace_); } else { return MakeError(file_, line_, code_, str, should_log_, log_severity_, should_log_stack_trace_); diff --git a/tensorflow/compiler/xla/test.h b/tensorflow/compiler/xla/test.h index 87a8c5f3a528289d47c1729ae6719aae47037c36..a657554dc2fd4fd1838639cac011bc0bb8b3d1eb 100644 --- a/tensorflow/compiler/xla/test.h +++ b/tensorflow/compiler/xla/test.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPLIER_XLA_TEST_H_ -#define TENSORFLOW_COMPLIER_XLA_TEST_H_ +#ifndef TENSORFLOW_COMPILER_XLA_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_TEST_H_ // This header includes gmock.h and enables the use of gmock matchers in tests // in third_party/tensorflow/compiler/xla. @@ -45,4 +45,4 @@ limitations under the License. #include "tensorflow/core/platform/test.h" -#endif // TENSORFLOW_COMPLIER_XLA_TEST_H_ +#endif // TENSORFLOW_COMPILER_XLA_TEST_H_ diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 8918350135fbb86973b228b35f5873fea8695b2f..3ede5e6e38a7a9e922fc0744f014c395dbd2324c 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 42d52aee780e2aade0f2ed3597e653567b8da49b..f474ecb18c75327edec449433c36a91d8ac7de83 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -29,6 +29,10 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites" load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -43,6 +47,7 @@ cc_library( "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], alwayslink = True, ) @@ -68,14 +73,14 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_headers_lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -98,6 +103,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -113,7 +121,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", @@ -127,6 +134,10 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -143,6 +154,27 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/core:lib", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "hlo_verified_test_base_test", + srcs = ["hlo_verified_test_base_test.cc"], + deps = [ + ":hlo_test_base", + ":hlo_verified_test_base", + ":test_macros_cpu", + ":test_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", "//tensorflow/core:test", ], ) @@ -187,7 +219,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", @@ -201,6 +232,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -274,6 +308,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -385,6 +421,8 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -551,6 +589,8 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -567,8 +607,7 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -591,8 +630,8 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -614,12 +653,11 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -633,6 +671,7 @@ xla_test( ], shard_count = 48, tags = [ + "broken", "manual", "notap", ], @@ -665,6 +704,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -683,7 +723,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -691,6 +730,7 @@ xla_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -709,6 +749,19 @@ xla_test( ], ) +xla_test( + name = "scatter_test", + srcs = ["scatter_test.cc"], + deps = [ + ":client_library_test_base", + ":hlo_test_base", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + # Repeat dot_operation_runtime_test with single-threaded eigen. xla_test( name = "dot_operation_single_threaded_runtime_test", @@ -727,7 +780,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -735,6 +787,7 @@ xla_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -798,6 +851,7 @@ CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -809,7 +863,10 @@ xla_test( timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS, + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], ) xla_test( @@ -819,7 +876,10 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS, + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], ) xla_test( @@ -870,6 +930,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -903,6 +964,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -979,6 +1041,10 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -1052,6 +1118,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1087,7 +1154,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1105,6 +1171,9 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -1133,6 +1202,9 @@ xla_test_library( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1140,6 +1212,7 @@ xla_test( name = "reduce_window_test", timeout = "long", srcs = [], + shard_count = 20, tags = [ "enable_for_xla_interpreter", "optonly", @@ -1195,6 +1268,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1205,12 +1279,12 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - ":client_library_test_base", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1221,12 +1295,12 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - ":client_library_test_base", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1270,6 +1344,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1335,6 +1410,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1385,7 +1461,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1396,6 +1471,9 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1410,11 +1488,11 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1428,14 +1506,12 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", @@ -1445,7 +1521,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1465,6 +1541,8 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1525,17 +1603,16 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1600,8 +1677,8 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1614,12 +1691,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1632,7 +1710,6 @@ xla_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", @@ -1643,6 +1720,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1736,13 +1814,14 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) tf_cc_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test_helpers", @@ -1757,6 +1836,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", "@llvm//:core", ], ) @@ -1777,6 +1857,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1790,15 +1871,11 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1808,6 +1885,8 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -1815,18 +1894,12 @@ xla_test( name = "multioutput_fusion_test", srcs = ["multioutput_fusion_test.cc"], deps = [ - "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", - "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1834,6 +1907,9 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1860,7 +1936,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1868,6 +1943,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:optional", ], ) @@ -1962,16 +2038,15 @@ xla_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1994,6 +2069,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -2035,6 +2111,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", ], ) @@ -2043,7 +2120,7 @@ tf_cc_test( name = "sample_file_test", srcs = ["sample_file_test.cc"], data = ["isolated_convolution.hlo"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":hlo_test_base", "//tensorflow/compiler/xla:test", @@ -2061,6 +2138,8 @@ tf_cc_test( xla_test( name = "test_utils_test", srcs = ["test_utils_test.cc"], + # There is nothing backend specific in this test, so just pick an arbitrary backend. + backends = ["cpu"], deps = [ ":local_client_test_base", ":test_utils", @@ -2069,6 +2148,7 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", "//tensorflow/core:test", ], ) @@ -2076,19 +2156,33 @@ xla_test( xla_test( name = "iota_test", srcs = ["iota_test.cc"], - blacklisted_backends = [ - "cpu", - "gpu", - ], + shard_count = 30, tags = [ "enable_for_xla_interpreter", + # Require optimized builds, iota_test_cpu is very slow in fastbuild. + "optonly", ], deps = [ ":client_library_test_base", - ":literal_test_util", ":xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "multiple_devices_on_host_test", + srcs = ["multiple_devices_on_host_test.cc"], + args = ["--xla_force_host_platform_device_count=4"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 74f2e36f826cd82ce4015df857f3de67950beaeb..c257566fb218d4769aec0c793efb9256b023b7ea 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -225,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 0x8000000000000000LL, 1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{1, 0x7FFFFFFFFFFFFFFLL, @@ -239,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0, 1, 0x8000000000000000LL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Add(lhs_param, rhs_param); @@ -265,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 1, 0, -1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{-1, 0, @@ -278,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Sub(lhs_param, rhs_param); @@ -293,6 +294,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()}); } +XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) { + XlaBuilder b(TestName()); + + std::vector lhs{static_cast(0x8000000000000000ULL)}; + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); + + std::vector rhs{static_cast(0x7FFFFFFFFFFFFFFFULL)}; + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); + + Lt(lhs_param, rhs_param); + + ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)}); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); XlaBuilder builder(TestName()); @@ -303,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); + Literal a_literal = LiteralUtil::CreateR1({a_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a_constant = ConstantR1(&builder, a_values); - auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); + auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param"); - std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); + Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); auto b_param = ConstantR1(&builder, b_values); auto sum1 = Add(a_constant, b_constant); @@ -411,7 +428,65 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { ComputeAndCompareR1(&builder, {}, {}, error_spec_); } -XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { +class IntegerDivideOpTest : public ArrayElementwiseOpTest { + protected: + template + void TestDivRem(absl::Span dividends, absl::Span divisors, + absl::Span quotients, + absl::Span remainders) { + { + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + auto divisor_data = + CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); + Div(dividend, divisor); + + ComputeAndCompareR1(&builder, quotients, + {dividend_data.get(), divisor_data.get()}); + } + + // Test with a compile-time constant divisor. + { + XlaBuilder builder(TestName()); + XlaOp dividend; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + Div(dividend, ConstantR1(&builder, divisors)); + + ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); + } + + { + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + auto divisor_data = + CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); + Rem(dividend, divisor); + + ComputeAndCompareR1(&builder, remainders, + {dividend_data.get(), divisor_data.get()}); + } + + // Test with a compile-time constant divisor. + { + XlaBuilder builder(TestName()); + XlaOp dividend; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + Rem(dividend, ConstantR1(&builder, divisors)); + + ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); + } + } +}; + +XLA_TEST_F(IntegerDivideOpTest, DivS32s) { // clang-format off // Some interesting values to test. std::vector vals = { @@ -435,58 +510,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } } - { - XlaBuilder builder(TestName()); - XlaOp dividend; - XlaOp divisor; - auto dividend_data = - CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - auto divisor_data = - CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - Div(dividend, divisor); - - ComputeAndCompareR1(&builder, quotients, - {dividend_data.get(), divisor_data.get()}); - } - - // Test with a compile-time constant divisor. - { - XlaBuilder builder(TestName()); - XlaOp dividend; - auto dividend_data = - CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - Div(dividend, ConstantR1(&builder, divisors)); - - ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); - } - - { - XlaBuilder builder(TestName()); - XlaOp dividend; - XlaOp divisor; - auto dividend_data = - CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - auto divisor_data = - CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - Rem(dividend, divisor); - - ComputeAndCompareR1(&builder, remainders, - {dividend_data.get(), divisor_data.get()}); - } + TestDivRem(dividends, divisors, quotients, remainders); +} - // Test with a compile-time constant divisor. - { - XlaBuilder builder(TestName()); - XlaOp dividend; - auto dividend_data = - CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - Rem(dividend, ConstantR1(&builder, divisors)); +XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) { + std::vector dividends = {5, INT32_MIN}, divisors = {0, -1}, + quotients = {-1, INT32_MIN}, remainders = {5, 0}; - ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); - } + TestDivRem(dividends, divisors, quotients, remainders); } -XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { +XLA_TEST_F(IntegerDivideOpTest, DivU32s) { // clang-format off // Some interesting values to test. std::vector vals = { @@ -506,53 +540,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } } - { - XlaBuilder builder(TestName()); - XlaOp dividend; - XlaOp divisor; - auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", - &builder, ÷nd); - auto divisor_data = - CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - Div(dividend, divisor); - - ComputeAndCompareR1(&builder, quotients, - {dividend_data.get(), divisor_data.get()}); - } - - { - XlaBuilder builder(TestName()); - XlaOp dividend; - auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", - &builder, ÷nd); - Div(dividend, ConstantR1(&builder, divisors)); - - ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); - } - - { - XlaBuilder builder(TestName()); - XlaOp dividend; - XlaOp divisor; - auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", - &builder, ÷nd); - auto divisor_data = - CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - Rem(dividend, divisor); - - ComputeAndCompareR1(&builder, remainders, - {dividend_data.get(), divisor_data.get()}); - } + TestDivRem(dividends, divisors, quotients, remainders); +} - { - XlaBuilder builder(TestName()); - XlaOp dividend; - auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", - &builder, ÷nd); - Rem(dividend, ConstantR1(&builder, divisors)); +XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) { + std::vector dividends = {5}, divisors = {0}, quotients = {-1}, + remainders = {5}; - ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); - } + TestDivRem(dividends, divisors, quotients, remainders); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { @@ -1426,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = LiteralUtil::CreateR1(values); + Literal param_literal = LiteralUtil::CreateR1(values); std::unique_ptr param_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto sum = ConstantR0(&b, 0.0f); - auto param = Parameter(&b, 0, param_literal->shape(), "param"); + auto param = Parameter(&b, 0, param_literal.shape(), "param"); for (float exponent : exponents) { sum = Add(sum, Pow(param, ConstantR0(&b, exponent))); } @@ -1454,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Pow(Exp(param0), param1); std::vector expected(values0.size()); @@ -1479,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Log(Pow(param0, param1)); std::vector expected(values0.size()); @@ -1504,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Mul(Exp(param0), Exp(param1)); std::vector expected(values0.size()); @@ -1529,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Div(param0, Exp(param1)); std::vector expected(values0.size()); @@ -1555,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + client_->TransferToServer(literal2).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(Div(param0, param1), param2); std::vector expected(values0.size()); @@ -1587,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Div(param1, param2)); std::vector expected(values0.size()); @@ -1620,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Pow(param1, param2)); std::vector expected(values0.size()); @@ -1654,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - std::unique_ptr literal3 = LiteralUtil::CreateR1(values3); + Literal literal3 = LiteralUtil::CreateR1(values3); std::unique_ptr data3 = - client_->TransferToServer(*literal3).ConsumeValueOrDie(); + client_->TransferToServer(literal3).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); - auto param3 = Parameter(&b, 3, literal3->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); + auto param3 = Parameter(&b, 3, literal3.shape(), "param2"); Div(Div(param0, param1), Div(param2, param3)); std::vector expected(values0.size()); @@ -2100,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, @@ -2122,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); Array3D expected(0, 7, 0); @@ -2144,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); - auto p = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p = Parameter(&builder, 0, param0_literal.shape(), "param0"); Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, @@ -2210,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { 0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05}); TF_ASSERT_OK_AND_ASSIGN(auto input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Tanh(input); ComputeAndCompareR1( @@ -2243,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. - std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, @@ -2256,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { 78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3, 86.4, 86.5, 87.6, 87.7, 87.8, 87.9}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Exp(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(std::exp(input_literal->Get({i}))); + expected_result.push_back(std::exp(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2277,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // implementation on XLA CPU. XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, @@ -2294,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { 1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33, 1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Log(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(std::log(input_literal->Get({i}))); + expected_result.push_back(std::log(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2469,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); Tuple(&builder, {cmp_dim_0, cmp_dim_1}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), - LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{true, true}, {true, false}}), + LiteralUtil::CreateR2({{true, false}, {false, false}})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { @@ -2825,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); - auto a = ConstantLiteral(&builder, *a_literal); + Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); + auto a = ConstantLiteral(&builder, a_literal); auto b = ConstantR1(&builder, r1); Add(a, b, {1}); @@ -2890,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { XlaBuilder builder(TestName()); auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); auto y_literal = LiteralUtil::CreateR1({4, 5}); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); - auto y = Parameter(&builder, 1, y_literal->shape(), "y"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); + auto y = Parameter(&builder, 1, y_literal.shape(), "y"); auto slice = Slice(x, {1}, {2}, {1}); Sub(slice, y); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 24b17b71007a1872462bed1f6b86ae1a5bb9922c..bc2ba151a38f1ab000b342dcd4bdd8f53d9ce9a9 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -63,7 +63,7 @@ class BatchNormalizationTest {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_)); + input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { @@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { @@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { BatchNormTraining(h0, h1, h2, /*epsilon=*/1, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) - .get(), - LiteralUtil::CreateR1(std::vector(260, 1.0f)).get(), - LiteralUtil::CreateR1(std::vector(260, 0.0f)).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)), + LiteralUtil::CreateR1(std::vector(260, 1.0f)), + LiteralUtil::CreateR1(std::vector(260, 0.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { BatchNormTraining(h0, h1, h2, /*epsilon=*/-100, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR3FromArray3D( - {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) - .get(), - LiteralUtil::CreateR1(std::vector(1, 15.0f)).get(), - LiteralUtil::CreateR1(std::vector(1, 125.0f)).get()}); + {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}), + LiteralUtil::CreateR1(std::vector(1, 15.0f)), + LiteralUtil::CreateR1(std::vector(1, 125.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) - .get(), - LiteralUtil::CreateR1({0, 0}).get(), - LiteralUtil::CreateR1({16, 20}).get()}); + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}), + LiteralUtil::CreateR1({0, 0}), + LiteralUtil::CreateR1({16, 20})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } struct BatchNormTestParam { @@ -382,7 +377,7 @@ struct BatchNormTestParam { friend ::std::ostream& operator<<(::std::ostream& os, const BatchNormTestParam& p) { - os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; + os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, "; os << "feature_index=" << p.feature_index << ", "; os << "random_value_mean=" << p.random_value_mean << ", "; os << "random_value_var=" << p.random_value_var; @@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); - auto expected = LiteralUtil::MakeTuple( - {expected_normalized.get(), LiteralUtil::CreateR1(mean).get(), - LiteralUtil::CreateR1(var).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_normalized, LiteralUtil::CreateR1(mean), + LiteralUtil::CreateR1(var)}); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); BatchNormTraining(input_activations, scale_activations, offset_activations, epsilon, feature_index); @@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { // testcase. execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); ComputeAndCompareTuple( - &builder, *expected, + &builder, expected, {input_data.get(), scale_data.get(), offset_data.get()}, ErrorSpec(0.01, 1)); } @@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); - auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); + auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean"); auto variance_activations = - Parameter(&builder, 4, var_literal->shape(), "variance"); + Parameter(&builder, 4, var_literal.shape(), "variance"); Array4D expected = normalized; std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr variance_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); BatchNormInference(input_activations, scale_activations, offset_activations, mean_activations, variance_activations, epsilon, @@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { auto grad_output_literal = LiteralUtil::CreateR4FromArray4D(grad_output_array); - auto input_parameter = - Parameter(&builder, 0, input_literal->shape(), "input"); - auto scale_parameter = - Parameter(&builder, 1, scale_literal->shape(), "scale"); - auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean"); - auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance"); + auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input"); + auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale"); + auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean"); + auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance"); auto grad_output_parameter = - Parameter(&builder, 4, grad_output_literal->shape(), "grad_output"); + Parameter(&builder, 4, grad_output_literal.shape(), "grad_output"); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr var_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); std::unique_ptr grad_output_data = - client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); + client_->TransferToServer(grad_output_literal).ConsumeValueOrDie(); BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter, grad_output_parameter, epsilon, feature_index); - auto expected = - LiteralUtil::MakeTuple({expected_grad_activation.get(), - LiteralUtil::CreateR1(grad_scale).get(), - LiteralUtil::CreateR1(grad_offset).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_grad_activation, LiteralUtil::CreateR1(grad_scale), + LiteralUtil::CreateR1(grad_offset)}); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor // testcase. execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {input_data.get(), scale_data.get(), mean_data.get(), var_data.get(), grad_output_data.get()}, ErrorSpec(0.01, 1)); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 6c20f654fe3df6a28e9633cd832c11b487894bad..e9728e636f0ee032416b2da17a3ea83c5bb18083 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -65,7 +65,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) { Log(x); ComputeAndCompareR0(&builder, static_cast(1.387f), {}, - error_spec_); + ErrorSpec(0.01, 0.01)); } XLA_TEST_F(Bfloat16Test, NegateScalarF16) { @@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-1.6875f)}, {static_cast(-2.04f)}}, {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, - {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) - .get(), + {{static_cast(3.7f)}, {static_cast(6.04f)}}}}), LiteralUtil::CreateR1( - {static_cast(4), static_cast(5)}) - .get(), + {static_cast(4), static_cast(5)}), LiteralUtil::CreateR1( - {static_cast(5), static_cast(5)}) - .get()}); + {static_cast(5), static_cast(5)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } XLA_TEST_F(Bfloat16Test, BatchNormGrad) { @@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, {{{static_cast(1.f)}, {static_cast(1.f)}}, - {{static_cast(3.f)}, {static_cast(3.f)}}}}) - .get(), + {{static_cast(3.f)}, {static_cast(3.f)}}}}), LiteralUtil::CreateR1( - {static_cast(0), static_cast(0)}) - .get(), + {static_cast(0), static_cast(0)}), LiteralUtil::CreateR1( - {static_cast(16), static_cast(20)}) - .get()}); + {static_cast(16), static_cast(20)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 1d28e85b16596b0ec2717138fb2081878203e8b2..dde19fb65d65064c9452a6ac49c70e20cf113336 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -53,29 +53,31 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { } } - std::unique_ptr MakeR3Data( - tensorflow::gtl::ArraySlice bounds, - tensorflow::gtl::ArraySlice minor_to_major, Shape* r3_shape, - Array3D* r3_array, float start, float end, int seed) { + std::unique_ptr MakeR3Data(absl::Span bounds, + absl::Span minor_to_major, + Shape* r3_shape, + Array3D* r3_array, float start, + float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( + auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = - client_->TransferToServer(*r3_data).ConsumeValueOrDie(); + client_->TransferToServer(r3_data).ConsumeValueOrDie(); return r3_global_data; } - std::unique_ptr MakeR2Data( - tensorflow::gtl::ArraySlice bounds, - tensorflow::gtl::ArraySlice minor_to_major, Shape* r2_shape, - Array2D* r2_array, float start, float end, int seed) { + std::unique_ptr MakeR2Data(absl::Span bounds, + absl::Span minor_to_major, + Shape* r2_shape, + Array2D* r2_array, float start, + float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( + auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = - client_->TransferToServer(*r2_data).ConsumeValueOrDie(); + client_->TransferToServer(r2_data).ConsumeValueOrDie(); return r2_global_data; } @@ -291,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}}), - ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); @@ -299,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R3ImplicitBroadcastSpec { @@ -348,7 +350,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], spec.output_bounds[2]); - auto Each = ([&](tensorflow::gtl::ArraySlice indices, float* value) { + auto Each = ([&](absl::Span indices, float* value) { float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0], indices[1] % spec.input_bounds[1], indices[2] % spec.input_bounds[2]); @@ -368,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, - {r3_implicit_global_data.get(), r3_global_data.get()}, + &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()}, ErrorSpec(1e-7, 1e-7)); } @@ -393,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, + ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}, {2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}, {2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + &b, LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R2ImplicitBroadcastSpec { @@ -616,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, + &builder, expected, {r2_implicit_global_data1.get(), r2_global_data.get(), r2_implicit_global_data2.get()}, ErrorSpec(1e-6, 1e-6)); @@ -628,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}})); - auto r2 = - ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}})); + auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1}, {2}})); - auto r2 = - ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1, {0}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {1}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {2}); auto expected = LiteralUtil::CreateR3( {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { @@ -695,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = ConstantR1(&b, {100, 200}); auto r1_2 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = Add(r1_0, r3, {0}); r3 = Add(r3, r1_1, {1}); @@ -707,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { @@ -728,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { @@ -737,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), - ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index c7b94b5bbaaa512ad36056f9e68a87cc706c24b1..9966e4606ef7f104487182e0240e64e4c9e4d834 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0(42.0), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0(42.0), result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), result, error_spec_)); } @@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralSlice(*result, {0}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + LiteralSlice(result, {0}), error_spec_)); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralSlice(*result, {1}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + LiteralSlice(result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), result, error_spec_)); } @@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), result, error_spec_)); } @@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - *result, error_spec_)); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array), + result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 53f2c3bfbfce9585cb68f103a495ce2f1ad8432e..05d4d04034bf50c8bb840e59b28a590fce048c19 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -3,256 +3,266 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) all_backends = ["cpu", "gpu"] + plugins.keys() def filter_backends(backends): - """Removes "gpu" from a backend list if CUDA is not enabled. - - This allows us to simply hardcode lists including "gpu" here and in the - BUILD file, without causing failures when CUDA isn't enabled.' - - Args: - backends: A list of backends to filter. - - Returns: - The filtered list of backends. - """ - if cuda_is_configured(): - return backends - else: - return [backend for backend in backends if backend != "gpu"] - - -def xla_test(name, - srcs, - deps, - xla_test_library_deps=[], - backends=[], - blacklisted_backends=[], - args=[], - tags=[], - copts=[], - data=[], - backend_tags={}, - backend_args={}, - **kwargs): - """Generates cc_test targets for the given XLA backends. - - This rule generates a cc_test target for one or more XLA backends and also a - platform-agnostic cc_library rule. The arguments are identical to cc_test with - two additions: 'backends' and 'backend_args'. 'backends' specifies the - backends to generate tests for ("cpu", "gpu"), and - 'backend_args'/'backend_tags' specifies backend-specific args parameters to - use when generating the cc_test. - - The name of the cc_tests are the provided name argument with the backend name - appended, and the cc_library target name is the provided name argument with - "_lib" appended. For example, if name parameter is "foo_test", then the cpu - test target will be "foo_test_cpu" and the cc_library target is "foo_lib". - - The cc_library target can be used to link with other plugins outside of - xla_test. - - The build rule also defines a test suite ${name} which includes the tests for - each of the supported backends. - - Each generated cc_test target has a tag indicating which backend the test is - for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These - tags can be used to gather tests for a particular backend into a test_suite. - - Examples: - - # Generates the targets: foo_test_cpu and foo_test_gpu. - xla_test( - name = "foo_test", - srcs = ["foo_test.cc"], - backends = ["cpu", "gpu"], - deps = [...], - ) + """Removes "gpu" from a backend list if CUDA is not enabled. - # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu - # includes the additional arg "--special_cpu_flag". - xla_test( - name = "bar_test", - srcs = ["bar_test.cc"], - backends = ["cpu", "gpu"], - backend_args = {"cpu": ["--special_cpu_flag"]} - deps = [...], - ) + This allows us to simply hardcode lists including "gpu" here and in the + BUILD file, without causing failures when CUDA isn't enabled.' - The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND} - to the value 1 where ${BACKEND} is the uppercase name of the backend. - - Args: - name: Name of the target. - srcs: Sources for the target. - deps: Dependencies of the target. - xla_test_library_deps: If set, the generated test targets will depend on the - respective cc_libraries generated by the xla_test_library rule. - backends: A list of backends to generate tests for. Supported values: "cpu", - "gpu". If this list is empty, the test will be generated for all supported - backends. - blacklisted_backends: A list of backends to NOT generate tests for. - args: Test arguments for the target. - tags: Tags for the target. - copts: Additional copts to pass to the build. - data: Additional data to pass to the build. - backend_tags: A dict mapping backend name to list of additional tags to - use for that target. - backend_args: A dict mapping backend name to list of additional args to - use for that target. - **kwargs: Additional keyword arguments to pass to native.cc_test. - """ - test_names = [] - if not backends: - backends = all_backends - - backends = [backend for backend in backends - if backend not in blacklisted_backends] - - native.cc_library( - name="%s_lib" % name, - srcs=srcs, - copts=copts, - testonly=True, - deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], - ) - - for backend in filter_backends(backends): - test_name = "%s_%s" % (name, backend) - this_backend_tags = ["xla_%s" % backend] - this_backend_copts = [] - this_backend_args = backend_args.get(backend, []) - this_backend_data = [] - if backend == "cpu": - backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] - elif backend == "gpu": - backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] - this_backend_tags += ["requires-gpu-sm35"] - elif backend in plugins: - backend_deps = [] - backend_deps += plugins[backend]["deps"] - this_backend_copts += plugins[backend]["copts"] - this_backend_tags += plugins[backend]["tags"] - this_backend_args += plugins[backend]["args"] - this_backend_data += plugins[backend]["data"] - else: - fail("Unknown backend %s" % backend) - - if xla_test_library_deps: - for lib_dep in xla_test_library_deps: - backend_deps += ["%s_%s" % (lib_dep, backend)] - - tf_cc_test( - name=test_name, - srcs=srcs, - tags=tags + backend_tags.get(backend, []) + this_backend_tags, - extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + - this_backend_copts, - args=args + this_backend_args, - deps=deps + backend_deps, - data=data + this_backend_data, - **kwargs) - - test_names.append(test_name) - - native.test_suite(name=name, tests=test_names) - -def xla_test_library(name, - srcs, - hdrs=[], - deps=[], - backends=[]): - """Generates cc_library targets for the given XLA backends. - - This rule forces the sources to be compiled for each backend so that the - backend specific macros could expand correctly. It's useful when test targets - in different directories referring to the same sources but test with different - arguments. - - Examples: - - # Generates the targets: foo_test_library_cpu and foo_test_gpu. - xla_test_library( - name = "foo_test_library", - srcs = ["foo_test.cc"], - backends = ["cpu", "gpu"], - deps = [...], - ) - # Then use the xla_test rule to generate test targets: - xla_test( - name = "foo_test", - srcs = [], - backends = ["cpu", "gpu"], - deps = [...], - xla_test_library_deps = [":foo_test_library"], - ) + Args: + backends: A list of backends to filter. - Args: - name: Name of the target. - srcs: Sources for the target. - hdrs: Headers for the target. - deps: Dependencies of the target. - backends: A list of backends to generate libraries for. - Supported values: "cpu", "gpu". If this list is empty, the - library will be generated for all supported backends. - """ - - if not backends: - backends = all_backends - - for backend in filter_backends(backends): - this_backend_copts = [] - if backend in ["cpu", "gpu"]: - backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] - elif backend in plugins: - backend_deps = plugins[backend]["deps"] - this_backend_copts += plugins[backend]["copts"] + Returns: + The filtered list of backends. + """ + if cuda_is_configured(): + return backends else: - fail("Unknown backend %s" % backend) + return [backend for backend in backends if backend != "gpu"] + +def xla_test( + name, + srcs, + deps, + xla_test_library_deps = [], + backends = [], + blacklisted_backends = [], + args = [], + tags = [], + copts = [], + data = [], + backend_tags = {}, + backend_args = {}, + **kwargs): + """Generates cc_test targets for the given XLA backends. + + This rule generates a cc_test target for one or more XLA backends and also a + platform-agnostic cc_library rule. The arguments are identical to cc_test with + two additions: 'backends' and 'backend_args'. 'backends' specifies the + backends to generate tests for ("cpu", "gpu"), and + 'backend_args'/'backend_tags' specifies backend-specific args parameters to + use when generating the cc_test. + + The name of the cc_tests are the provided name argument with the backend name + appended, and the cc_library target name is the provided name argument with + "_lib" appended. For example, if name parameter is "foo_test", then the cpu + test target will be "foo_test_cpu" and the cc_library target is "foo_lib". + + The cc_library target can be used to link with other plugins outside of + xla_test. + + The build rule also defines a test suite ${name} which includes the tests for + each of the supported backends. + + Each generated cc_test target has a tag indicating which backend the test is + for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These + tags can be used to gather tests for a particular backend into a test_suite. + + Examples: + + # Generates the targets: foo_test_cpu and foo_test_gpu. + xla_test( + name = "foo_test", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + + # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu + # includes the additional arg "--special_cpu_flag". + xla_test( + name = "bar_test", + srcs = ["bar_test.cc"], + backends = ["cpu", "gpu"], + backend_args = {"cpu": ["--special_cpu_flag"]} + deps = [...], + ) + + The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND} + to the value 1 where ${BACKEND} is the uppercase name of the backend. + + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + xla_test_library_deps: If set, the generated test targets will depend on the + respective cc_libraries generated by the xla_test_library rule. + backends: A list of backends to generate tests for. Supported values: "cpu", + "gpu". If this list is empty, the test will be generated for all supported + backends. + blacklisted_backends: A list of backends to NOT generate tests for. + args: Test arguments for the target. + tags: Tags for the target. + copts: Additional copts to pass to the build. + data: Additional data to pass to the build. + backend_tags: A dict mapping backend name to list of additional tags to + use for that target. + backend_args: A dict mapping backend name to list of additional args to + use for that target. + **kwargs: Additional keyword arguments to pass to native.cc_test. + """ + test_names = [] + if not backends: + backends = all_backends + + backends = [ + backend + for backend in backends + if backend not in blacklisted_backends + ] native.cc_library( - name = "%s_%s" % (name, backend), + name = "%s_lib" % name, srcs = srcs, + copts = copts, testonly = True, - hdrs = hdrs, - copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] - + this_backend_copts, - deps = deps + backend_deps, + deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], ) - -def generate_backend_suites(backends=[]): - if not backends: - backends = all_backends - for backend in filter_backends(backends): - native.test_suite(name="%s_tests" % backend, - tags = ["xla_%s" % backend]) - - -def generate_backend_test_macros(backends=[]): - if not backends: - backends = all_backends - for backend in filter_backends(backends): - manifest = "" - if backend in plugins: - manifest = plugins[backend]["disabled_manifest"] - - native.cc_library( - name="test_macros_%s" % backend, - testonly = True, - srcs = ["test_macros.cc"], - hdrs = ["test_macros.h"], - copts = [ - "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), - "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, - ], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:test", - ]) + for backend in filter_backends(backends): + test_name = "%s_%s" % (name, backend) + this_backend_tags = ["xla_%s" % backend] + this_backend_copts = [] + this_backend_args = backend_args.get(backend, []) + this_backend_data = [] + if backend == "cpu": + backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] + elif backend == "gpu": + backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] + this_backend_tags += tf_cuda_tests_tags() + elif backend in plugins: + backend_deps = [] + backend_deps += plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + this_backend_tags += plugins[backend]["tags"] + this_backend_args += plugins[backend]["args"] + this_backend_data += plugins[backend]["data"] + else: + fail("Unknown backend %s" % backend) + + if xla_test_library_deps: + for lib_dep in xla_test_library_deps: + backend_deps += ["%s_%s" % (lib_dep, backend)] + + tf_cc_test( + name = test_name, + srcs = srcs, + tags = tags + backend_tags.get(backend, []) + this_backend_tags, + extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + args = args + this_backend_args, + deps = deps + backend_deps, + data = data + this_backend_data, + **kwargs + ) + + test_names.append(test_name) + + native.test_suite(name = name, tests = test_names) + +def xla_test_library( + name, + srcs, + hdrs = [], + deps = [], + backends = []): + """Generates cc_library targets for the given XLA backends. + + This rule forces the sources to be compiled for each backend so that the + backend specific macros could expand correctly. It's useful when test targets + in different directories referring to the same sources but test with different + arguments. + + Examples: + + # Generates the targets: foo_test_library_cpu and foo_test_gpu. + xla_test_library( + name = "foo_test_library", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + # Then use the xla_test rule to generate test targets: + xla_test( + name = "foo_test", + srcs = [], + backends = ["cpu", "gpu"], + deps = [...], + xla_test_library_deps = [":foo_test_library"], + ) + + Args: + name: Name of the target. + srcs: Sources for the target. + hdrs: Headers for the target. + deps: Dependencies of the target. + backends: A list of backends to generate libraries for. + Supported values: "cpu", "gpu". If this list is empty, the + library will be generated for all supported backends. + """ + + if not backends: + backends = all_backends + + for backend in filter_backends(backends): + this_backend_copts = [] + if backend in ["cpu", "gpu"]: + backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] + elif backend in plugins: + backend_deps = plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + else: + fail("Unknown backend %s" % backend) + + native.cc_library( + name = "%s_%s" % (name, backend), + srcs = srcs, + testonly = True, + hdrs = hdrs, + copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + deps = deps + backend_deps, + ) + +def generate_backend_suites(backends = []): + if not backends: + backends = all_backends + for backend in filter_backends(backends): + native.test_suite( + name = "%s_tests" % backend, + tags = ["xla_%s" % backend, "-broken", "manual"], + ) + +def generate_backend_test_macros(backends = []): + if not backends: + backends = all_backends + for backend in filter_backends(backends): + manifest = "" + if backend in plugins: + manifest = plugins[backend]["disabled_manifest"] + + native.cc_library( + name = "test_macros_%s" % backend, + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], + copts = [ + "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), + "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, + ], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:test", + ], + ) diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index b1d18210eaafdfec0920c0cccaa0dfdbd6de5609..8b31e53707eee456e09adfe9fb76f03a8855056d 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); - auto constant = - ConstantLiteral(&builder, *LiteralUtil::CreateR0(42.0)); + auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0(42.0)); Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); @@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); - auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); + auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); + auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); auto x = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({1.0f, 2.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({1.0f, 2.0f})); auto y = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({2.0f, 3.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({2.0f, 3.0f})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr start, - client_->TransferToServer(*LiteralUtil::CreateR0(1.0f))); + client_->TransferToServer(LiteralUtil::CreateR0(1.0f))); ComputeAndCompareR0(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); } @@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32TupleComputation(); auto elem = LiteralUtil::CreateR0(42.0); - auto tuple = LiteralUtil::MakeTuple({elem.get()}); - Call(&builder, callee, {ConstantLiteral(&builder, *elem)}); + auto tuple = LiteralUtil::MakeTuple({&elem}); + Call(&builder, callee, {ConstantLiteral(&builder, elem)}); - ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); + ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index a4eb57fc7b9abd460a7d158d0dc629eba88018cd..2f1510ff6969757f8091e9c043b61cb2a467ccd5 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); - auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1"); Add(p0, p1); auto param0_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto param1_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); @@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { auto computation = computation_status.ConsumeValueOrDie(); auto f32_literal = LiteralUtil::CreateR0(1.1f); - auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); + auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie(); auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = - client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); + client_->TransferToServer(f32_4_literal).ConsumeValueOrDie(); auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); - auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); + auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie(); // Match auto status = client_->Execute( diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 59d917054be2ebe3a25f902f51972a682a5231b6..fbdf0fcb6543f09dedefef55cfe0f8a5d9067d5a 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -17,18 +17,18 @@ limitations under the License. #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -95,15 +95,14 @@ string ClientLibraryTestBase::TestName() const { } StatusOr> ClientLibraryTestBase::Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( + const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { @@ -114,18 +113,16 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( + XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } -StatusOr> -ClientLibraryTestBase::ExecuteAndTransferReference( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, +StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( + const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { @@ -138,7 +135,7 @@ ClientLibraryTestBase::ExecuteAndTransferReference( } string ClientLibraryTestBase::ExecuteToString( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { auto computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status().ToString(); @@ -150,29 +147,28 @@ string ClientLibraryTestBase::ExecuteToString( if (!result.ok()) { return result.status().ToString(); } else { - return result.ValueOrDie()->ToString(); + return result.ValueOrDie().ToString(); } } void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } void ClientLibraryTestBase::ComputeAndCompareLiteral( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout) { + absl::Span arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } void ClientLibraryTestBase::ComputeAndCompareLiteral( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + absl::Span arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, error, shape_with_layout)); @@ -180,12 +176,12 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output) { // Try with no layout requirement. TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments)); - verify_output(*actual, ""); + verify_output(actual, ""); // Try with all output layouts. std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); @@ -196,8 +192,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( AsInt64Slice(expected.shape().dimensions()), minor_to_major); TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, &layout)); - verify_output(*actual, tensorflow::strings::StrCat( - "Test with output layout: ", + verify_output(actual, + absl::StrCat("Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); return Status::OK(); @@ -205,7 +201,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output, const Shape* output_with_layout) { @@ -221,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_ASSIGN_OR_RETURN(auto literal, client_->Transfer(*arguments[index], nullptr)); // Skip tuples because they don't have a rank. - if (ShapeUtil::IsTuple(literal->shape())) { + if (ShapeUtil::IsTuple(literal.shape())) { layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal->shape())); + ShapeUtil::HumanStringWithLayout(literal.shape())); arguments_with_layout.push_back(arguments[index]); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -231,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); + std::vector minor_to_major(ShapeUtil::Rank(literal.shape())); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = - literal->Relayout(LayoutUtil::MakeLayout(minor_to_major)); + literal.Relayout(LayoutUtil::MakeLayout(minor_to_major)); layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal_relayout->shape())); + ShapeUtil::HumanStringWithLayout(literal_relayout.shape())); TF_ASSIGN_OR_RETURN(auto data, - client_->TransferToServer(*literal_relayout)); + client_->TransferToServer(literal_relayout)); arguments_with_layout.push_back(data.get()); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -252,15 +248,14 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( // Every argument has an assigned layout. TF_ASSIGN_OR_RETURN( auto actual, - ExecuteAndTransfer( - computation, - tensorflow::gtl::ArraySlice(arguments_with_layout), - output_with_layout)); + ExecuteAndTransfer(computation, + absl::Span(arguments_with_layout), + output_with_layout)); string error_message = "Test with input layouts: "; for (const auto& str : layout_strings) { - tensorflow::strings::StrAppend(&error_message, str, " "); + absl::StrAppend(&error_message, str, " "); } - verify_output(*actual, error_message); + verify_output(actual, error_message); return Status::OK(); }; @@ -269,7 +264,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments_passed_in, + absl::Span arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); @@ -290,19 +285,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; - } else { - TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || - expected.shape().element_type() == PRED) - << ShapeUtil::HumanString(expected.shape()); } // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -327,14 +318,14 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)); return Status::OK(); } Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments_passed_in, - ErrorSpec error, const Shape* shape_with_layout) { + absl::Span arguments_passed_in, ErrorSpec error, + const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); @@ -350,17 +341,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } - TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || - ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -386,13 +375,13 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)); return Status::OK(); } void ClientLibraryTestBase::ComputeAndCompareR1U8( - XlaBuilder* builder, tensorflow::StringPiece expected, - tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::string_view expected, + absl::Span arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { @@ -401,66 +390,65 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); + Literal expected_literal = LiteralUtil::CreateR1U8(expected); - VLOG(1) << "expected: " << expected_literal->ToString(); - VLOG(1) << "actual: " << actual->ToString(); + VLOG(1) << "expected: " << expected_literal.ToString(); + VLOG(1) << "actual: " << actual.ToString(); - EXPECT_EQ(expected, actual->GetR1U8AsString()); + EXPECT_EQ(expected, actual.GetR1U8AsString()); } void ClientLibraryTestBase::ComputeAndCompareTuple( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } void ClientLibraryTestBase::ComputeAndCompareTuple( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error)); } void ClientLibraryTestBase::ComputeAndCompare( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { auto status_or_data = ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(reference, result)); } void ClientLibraryTestBase::ComputeAndCompare( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, - ErrorSpec error) { + XlaBuilder* builder, absl::Span arguments, ErrorSpec error) { auto status_or_data = ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error)); } -StatusOr, std::unique_ptr>> +StatusOr> ClientLibraryTestBase::ComputeValueAndReference( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's // into a vector to keep the data alive on the service until the end of this // function. @@ -546,7 +534,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { std::unique_ptr> ClientLibraryTestBase::CreatePatternedMatrix( int rows, int cols, float offset) { - auto array = MakeUnique>(rows, cols); + auto array = absl::make_unique>(rows, cols); for (int64 row = 0; row < rows; ++row) { for (int64 col = 0; col < cols; ++col) { (*array)(row, col) = col + (row * 1000.0f) + offset; @@ -561,7 +549,7 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, int cols_padded) { CHECK_GE(rows_padded, rows); CHECK_GE(cols_padded, cols); - auto array = MakeUnique>(rows_padded, cols_padded, 0.0); + auto array = absl::make_unique>(rows_padded, cols_padded, 0.0); for (int64 row = 0; row < rows; ++row) { for (int64 col = 0; col < cols; ++col) { (*array)(row, col) = col + (row * 1000.0f); @@ -580,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return ConstantLiteral(builder, use_bfloat16_ - ? *LiteralUtil::ConvertF32ToBF16(literal) - : literal); + ? LiteralUtil::ConvertF32ToBF16(literal) + : LiteralSlice(literal)); } std::unique_ptr @@ -611,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( const Literal& literal) { if (use_bfloat16_) { - return std::move(*LiteralUtil::ConvertF32ToBF16(literal)); + return LiteralUtil::ConvertF32ToBF16(literal); } return literal.Clone(); } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 4a6e8a31241d39db21935576d57f0acb17caef11..9d32f4f5174a57a53a9d3e6477b46fa4de852f7f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -21,6 +21,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -30,14 +33,11 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -49,8 +49,8 @@ namespace xla { // use_bfloat16_params with that value. Returns the result. template std::vector ExpandUseBfloat16( - tensorflow::gtl::ArraySlice use_bfloat16_params, - tensorflow::gtl::ArraySlice specs) { + absl::Span use_bfloat16_params, + absl::Span specs) { std::vector expanded; for (bool use_bfloat16 : use_bfloat16_params) { for (const auto& spec : specs) { @@ -74,8 +74,9 @@ class ClientLibraryTestBase : public ::testing::Test { string TestName() const; void SetFastMathDisabled(bool disabled) { - execution_options_.mutable_debug_options()->set_xla_enable_fast_math( - !disabled); + auto* opts = execution_options_.mutable_debug_options(); + opts->set_xla_cpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_math(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } @@ -92,29 +93,29 @@ class ClientLibraryTestBase : public ::testing::Test { // execution options. Modify execution_options_ in your test if you want to // customize the options. StatusOr> Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); + XlaBuilder* builder, absl::Span arguments); - StatusOr> ExecuteAndTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + StatusOr ExecuteAndTransfer( + XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_output_layout = nullptr); // This executes the computation via the reference client (which connects a // interpreter backend). The result is used as the expected values of the // computation. - StatusOr> ExecuteAndTransferReference( + StatusOr ExecuteAndTransferReference( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_output_layout = nullptr); // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. string ExecuteToString(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are @@ -124,102 +125,98 @@ class ClientLibraryTestBase : public ::testing::Test { // for integral types without the ErrorSpec parameter. template void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span expected, + absl::Span arguments); template void ComputeAndCompareR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span expected, + absl::Span arguments, ErrorSpec error); // As above, but uses a bitmap to hold the predicate vector to avoid // deficiencies of vector. void ComputeAndCompareR1(XlaBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. - void ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout = nullptr); - void ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout = nullptr); + void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + const Shape* shape_with_layout = nullptr); + void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + ErrorSpec error, + const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_layout = nullptr); Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + absl::Span arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // Compare the result of the computation to a strings. In XLA strings are // represented using rank-1 U8 shapes. - void ComputeAndCompareR1U8( - XlaBuilder* builder, tensorflow::StringPiece expected, - tensorflow::gtl::ArraySlice arguments); + void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected, + absl::Span arguments); // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. - void ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - void ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected, + absl::Span arguments); + void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + ErrorSpec error); // Convenience method for running a built computation and comparing the result // with the reference result. void ComputeAndCompare(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); void ComputeAndCompare(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments, - ErrorSpec error); + absl::Span arguments, ErrorSpec error); // Create scalar operations for use in reductions. XlaComputation CreateScalarRelu(); @@ -285,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { - return AddParam(*LiteralUtil::CreateFromArray(argument), builder); + return AddParam(LiteralUtil::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. When the @@ -300,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array), + return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array), builder); } // Same as CreateConstantFromArray, but for scalars. template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateR0(value), + return CreateConstantFromLiteral(LiteralUtil::CreateR0(value), builder); } @@ -336,7 +333,7 @@ class ClientLibraryTestBase : public ::testing::Test { // converted to bfloat16. template std::unique_ptr CreateR1Parameter( - tensorflow::gtl::ArraySlice values, int64 parameter_number, + absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array @@ -378,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test { // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, // actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + StatusOr> ComputeValueAndReference( + XlaBuilder* builder, absl::Span arguments); Client* client_; Client* ref_client_; // To compute reference result. @@ -389,12 +385,12 @@ class ClientLibraryTestBase : public ::testing::Test { private: Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output); Status ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output, const Shape* output_with_layout = nullptr); @@ -414,130 +410,126 @@ class ClientLibraryTestBase : public ::testing::Test { template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } template void ClientLibraryTestBase::ComputeAndCompareR1( - XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + XlaBuilder* builder, absl::Span expected, + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR1( - XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + XlaBuilder* builder, absl::Span expected, + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -545,27 +537,27 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR0(value); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR0(value); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( - tensorflow::gtl::ArraySlice values, int64 parameter_number, + absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR1(values); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR1(values); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -573,13 +565,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -587,13 +579,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -612,7 +604,7 @@ template std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( const int rows, const int cols, NativeT min_value, NativeT max_value, uint32 seed) { - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); PseudorandomGenerator generator(min_value, max_value, seed); for (int y = 0; y < rows; ++y) { for (int x = 0; x < cols; ++x) { diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index c898dacf489db97223e2918414daf5de88bece64..6f2ca84bb646e88af221ab80b727911ff7d990eb 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { std::unique_ptr data, client_->Execute(computation, {}, &execution_options)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); TF_ASSERT_OK_AND_ASSIGN( - auto computed, client_->Transfer(*data, &expected_literal->shape())); + auto computed, client_->Transfer(*data, &expected_literal.shape())); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } @@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralSlice(*result, {0})); + LiteralSlice(result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralSlice(*result, {1})); + LiteralSlice(result, {1})); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape())); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 0), + ShapeUtil::GetTupleElementShape(result.shape(), 0), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{0, 1}))); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 1), + ShapeUtil::GetTupleElementShape(result.shape(), 1), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0}))); } @@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr const_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); + LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); Add(Parameter(&b, 0, shape, "param_0"), @@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN( auto result_literal, - client_->Transfer(*results[0], &expected_result->shape())); + client_->Transfer(*results[0], &expected_result.shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 7c52c9fbbb57f9291ea9f0966e2efa715819fb67..6ef7ca035f75966bef12c7abcb55cb59e9b73655 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -38,25 +38,24 @@ namespace { class CompilationCacheTest : public ClientLibraryTestBase { public: - void ExecuteComputationR0F32( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, float expected_result, - bool expect_cache_hit) { + void ExecuteComputationR0F32(const XlaComputation& computation, + absl::Span arguments, + float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; - std::unique_ptr result = + Literal result = client_ ->ExecuteAndTransfer(computation, arguments, /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR0(expected_result), *result, error_spec_)); + LiteralUtil::CreateR0(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } void ExecuteComputationR2F32( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, std::initializer_list> expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; @@ -64,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase { ->Execute(computation, arguments, &execution_options_, &execution_profile) .ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data_handle).ConsumeValueOrDie(); + Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2(expected_result), *result, error_spec_)); + LiteralUtil::CreateR2(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -89,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(456.0f)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); @@ -146,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { auto rowmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = - client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(rowmaj_array).ConsumeValueOrDie(); auto colmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = - client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(colmaj_array).ConsumeValueOrDie(); XlaBuilder builder(TestName()); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 5a06d061f0d83fff547502495ff8ab13fb421b70..3b0414a6045a7c5f4f75948d8ccf2775c575626e 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/match.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test { LOG(FATAL) << "invalid client_type value"; } - StatusOr> ComputeConstantLiteral( - Client* client, const XlaOp& operand, XlaBuilder* builder, - Layout* output_layout = nullptr) { + StatusOr ComputeConstantLiteral(Client* client, const XlaOp& operand, + XlaBuilder* builder, + Layout* output_layout = nullptr) { TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand)); TF_ASSIGN_OR_RETURN(auto computed, client->ComputeConstant(subgraph, output_layout)); @@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test { XlaBuilder* builder) { TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, builder, nullptr)); - return literal->Get({}); + return literal.Get({}); } bool IsConstant(const XlaOp& operand, XlaBuilder* builder) { @@ -145,8 +145,8 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), - "depends on a parameter")) + EXPECT_TRUE( + absl::StrContains(value.status().ToString(), "depends on a parameter")) << value.status(); } } @@ -161,8 +161,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), - "depends on a parameter")) + EXPECT_TRUE( + absl::StrContains(value.status().ToString(), "depends on a parameter")) << value.status(); } } @@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = LiteralUtil::CreateR1({4, 6}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = LiteralUtil::CreateR0(5); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) { ConstantR2(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index be017477d84eb9faf5aa79dcdf54d6b6aaf6fd8e..9811a015e91d866d6f4de6ebb6dac536ed6c7e06 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); auto x_literal = LiteralUtil::CreateR0(2.f); auto y_literal = LiteralUtil::CreateR0(3.f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto x = Parameter(&builder, 0, f32_scalar, "x"); @@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "z"); auto bcast = Broadcast(y, {5}); @@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "y"); auto y_bcast = Broadcast(y, {1, 5, 7}); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index b27c1044baf2c0002f166c53a81e4361c60d012a..32cac499c7439af80bafb88ac61b0b078f589599 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12.0f).get(), - LiteralUtil::CreateR0(25.0f).get()}), + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(12.0f), + LiteralUtil::CreateR0(25.0f)}), {pred_arg.get()}, error_spec_); } @@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, CreateR1TupleFloorComputation()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({13.0f, 16.0f}).get(), - LiteralUtil::CreateR1({26.0f, 30.0f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({13.0f, 16.0f}), + LiteralUtil::CreateR1({26.0f, 30.0f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of a predicate, a @@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, false_builder_result.ConsumeValueOrDie()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(true).get(), - LiteralUtil::CreateR0(12.2f).get(), - LiteralUtil::CreateR1({12.8f, 14.6f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(true), + LiteralUtil::CreateR0(12.2f), + LiteralUtil::CreateR1({12.8f, 14.6f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a nested tuple. @@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(46.6f).get(), - LiteralUtil::CreateR1({54.4f, 58.4f}).get()}) - .get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({62.1f, 67.4f}).get(), - LiteralUtil::CreateR0(9.3f).get()}) - .get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(46.6f), + LiteralUtil::CreateR1({54.4f, 58.4f})}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({62.1f, 67.4f}), + LiteralUtil::CreateR0(9.3f)})}), {pred_arg.get()}, error_spec_); } @@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(a).get(), - LiteralUtil::CreateR0(b).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(a), LiteralUtil::CreateR0(b)}), {x_arg.get(), y_arg.get()}, error_spec_); }; @@ -642,5 +638,57 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { test_swap(11.24f, 5.55f); } +// Test conditional that duplicates tuple elements in the then and else +// computations. This is a regression test for b/112550242. +XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { + const Shape scalar = ShapeUtil::MakeShape(S32, {}); + const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar}); + XlaComputation then_comp; + { + XlaBuilder builder(TestName() + ".then"); + auto p = Parameter(&builder, 0, tuple2, "then.p"); + auto e0 = GetTupleElement(p, 0); + auto e1 = GetTupleElement(p, 1); + Tuple(&builder, {e0, e1, e0}); + then_comp = builder.Build().ConsumeValueOrDie(); + } + XlaComputation else_comp; + { + XlaBuilder builder(TestName() + ".else"); + auto p = Parameter(&builder, 0, tuple2, "else.p"); + auto e0 = GetTupleElement(p, 0); + auto e1 = GetTupleElement(p, 1); + Tuple(&builder, {e0, e1, e1}); + else_comp = builder.Build().ConsumeValueOrDie(); + } + + { + // Pred is true case. + std::vector args; + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(true)); + XlaBuilder builder(TestName() + ".main"); + auto p = Parameter(&builder, 0, tuple2, "p0"); + auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); + Conditional(p_pred, p, then_comp, p, else_comp); + ComputeAndCompare(&builder, args); + } + { + // Pred is false case. + std::vector args; + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(false)); + XlaBuilder builder(TestName() + ".main"); + auto p = Parameter(&builder, 0, tuple2, "p0"); + auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); + Conditional(p_pred, p, then_comp, p, else_comp); + ComputeAndCompare(&builder, args); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 49375748319ad5fe40db507a034ec4b07adb7e84..72ff1e74a47c8584cb5336c86a1c978c4637a902 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D( + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D( Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); @@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4D(input_array); + Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array); { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *input_literal); + ConstantLiteral(&builder, input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})})); - std::unique_ptr result = - ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); + Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, - LiteralSlice(*result, {0}), error_spec_); - LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(*result, {1}), + LiteralSlice(result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(result, {1}), error_spec_); } TEST_F(ConstantsTest, Token) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateToken()); + ConstantLiteral(&builder, LiteralUtil::CreateToken()); // TODO(b/80000000): tokens cannot be returned from computations. Tuple(&builder, {}); TF_ASSERT_OK(Execute(&builder, {}).status()); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 1adc68cc4839dcd7d89741ec016f27bc9047c9a5..5f063e67847487f1d18bf4ee80b1634ebdf4183a 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -209,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000008000000000LL), static_cast(0x8000010000000000LL), }; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -228,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -246,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, U32); @@ -263,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -280,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -317,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -447,15 +448,15 @@ std::vector GetInterestingF16ConversionTestCases() { XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { std::vector test_cases = GetInterestingF16ConversionTestCases(); std::vector input; - c_transform(test_cases, std::back_inserter(input), - [](float f) { return Eigen::half(f); }); + absl::c_transform(test_cases, std::back_inserter(input), + [](float f) { return Eigen::half(f); }); std::vector expected_output; - c_transform(input, std::back_inserter(expected_output), - [](Eigen::half h) { return static_cast(h); }); + absl::c_transform(input, std::back_inserter(expected_output), + [](Eigen::half h) { return static_cast(h); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( @@ -470,12 +471,12 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { std::vector input = GetInterestingF16ConversionTestCases(); std::vector expected_output; - c_transform(input, std::back_inserter(expected_output), - [](float f) { return Eigen::half(f); }); + absl::c_transform(input, std::back_inserter(expected_output), + [](float f) { return Eigen::half(f); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 7b6bbc4f571af2e11306f95c24e243e78e0f4f4e..fd98bf29b8a06d7476d51174b61c6268750db2ec 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -88,13 +88,12 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { XLA_TEST_F(ConvolutionDimensionNumbersTest, TwoConvsWithDifferentDimensionNumbers) { - auto input_array = MakeUnique>(2, 3, 5, 5); + auto input_array = absl::make_unique>(2, 3, 5, 5); input_array->FillWithMultiples(0.1); - auto weight_array = MakeUnique>(4, 3, 1, 1); + auto weight_array = absl::make_unique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) + client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 5ed8122e0073bde77bb2507a0ddd89c4365627c9..070b092d18930027e215cb43ff917e36cac99f12 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -26,16 +28,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -70,16 +70,16 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { const int kKernelSizeY = 2; const int kOutputActivationSizeZ = 256; const int kMiniBatchSize = 4; - auto alhs = - MakeUnique>(kMiniBatchSize, kInputActivationSizeZ, - kInputActivationSizeY, kInputActivationSizeX); + auto alhs = absl::make_unique>( + kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY, + kInputActivationSizeX); alhs->FillWithMultiples(static_cast(1.0f)); ASSERT_EQ(3, alhs->width()); ASSERT_EQ(3, alhs->height()); - auto arhs = - MakeUnique>(kOutputActivationSizeZ, kInputActivationSizeZ, - kKernelSizeY, kKernelSizeX); + auto arhs = absl::make_unique>(kOutputActivationSizeZ, + kInputActivationSizeZ, + kKernelSizeY, kKernelSizeX); Array2D rhs_raster({ {1.0f, 0.0f}, // row 0 {0.0f, 0.0f}, // row 1 @@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {7.0f, 8.0f}, })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { Array3D expected({{{570.0f, 670.0f, 770.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota(input_elems.begin(), input_elems.end(), 1.0f); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota(filter_elems.begin(), filter_elems.end(), 1.0f); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); - auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); + auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); - auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie(); + auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r5).ConsumeValueOrDie(); + client_->TransferToServer(filter_r5).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r5, + ComputeAndCompareLiteral(&builder, expected_r5, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -465,7 +465,7 @@ void iota_int_init_value(std::vector& values, int init_value) { } template -class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { +class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest { public: void RunTest() { XlaBuilder builder(TestName()); @@ -498,30 +498,161 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(92115), static_cast(93150), static_cast(94185)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } }; -TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x5_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x5_Valid, Types) { this->RunTest(); } +TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) { this->RunTest(); } + +template +class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 3, 3, 5}; + std::vector filter_dims = {3, 3, 1, 15}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(16029), static_cast(16218), static_cast(16407), + static_cast(17172), static_cast(17370), static_cast(17568), + static_cast(18369), static_cast(18576), static_cast(18783), + static_cast(19620), static_cast(19836), static_cast(20052), + static_cast(20925), static_cast(21150), static_cast(21375)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 6}; + std::vector filter_dims = {2, 2, 2, 12}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/3); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(5076), static_cast(5160), static_cast(5244), + static_cast(5328), static_cast(6164), static_cast(6264), + static_cast(6364), static_cast(6464), static_cast(7380), + static_cast(7496), static_cast(7612), static_cast(7728)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, Types) { + this->RunTest(); +} // Test fixture to run convolution tests with and without convolution // canonicalization enabled. @@ -561,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, expected_result.Fill(0); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(param0)), - std::move(*LiteralUtil::CreateFromArray(param1))}, + {LiteralUtil::CreateFromArray(param0), + LiteralUtil::CreateFromArray(param1)}, error_spec_); } @@ -618,26 +749,25 @@ class Convolve1D1WindowTestBase std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1.0f)); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(1.0f)); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); std::vector expect_elems(batch * output_feature * num_windows, static_cast(window_size * input_feature)); auto expected_r1 = LiteralUtil::CreateR1(expect_elems); - auto expected_r3 = - expected_r1->Reshape({batch, num_windows, output_feature}) - .ConsumeValueOrDie(); + auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r3).ConsumeValueOrDie(); + client_->TransferToServer(input_r3).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r3, + client_->TransferToServer(filter_r3).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, expected_r3, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -737,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } @@ -760,9 +890,83 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { Array4D filter_data(1, 1, 1, 2); filter_data.FillIota(10); - ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}); + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}); +} + +XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100}); + Array4D input_data(1, 64, 100, 100); + input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45321); + Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64}); + Array4D filter_data(7, 7, 1, 64); + input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45320); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = ConstantR4FromArray4D(&builder, filter_data); + + // Specify bf01_01io->bf01 as dimension numbers. + ConvolutionDimensionNumbers dnums; + // Input + dnums.set_input_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.add_input_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + // Kernel + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + // Output + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(1); + dnums.add_output_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(3); + ConvGeneral(input, filter, /*window_strides=*/{1, 1}, + /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums, + /*feature_group_count=*/64); + + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)}, + error_spec_); +} + +class ConvolutionHloTest : public HloTestBase {}; + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64Forward)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f64[3,56,56,16] parameter(0) + %arg1 = f64[3,3,3,64] parameter(1) + ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f64[2,5,8,1] parameter(0) + %arg1 = f64[2,5,8,2] parameter(1) + ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardInput)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %output = f64[4,5,16,16] parameter(0) + %kernel = f64[5,3,7,7] parameter(1) + %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3} + ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, dim_labels=bf01_io01->bf01 +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 6784c16715da72d337edf70fa51db42c59404136..ba3e9c436e3cfa574a07e881a187ff4c7d6243a1 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { auto gradients_flat = LiteralUtil::CreateR1({1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); auto weights_literal = - weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto weights = ConstantLiteral(&builder, *weights_literal); + weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto weights = ConstantLiteral(&builder, weights_literal); auto expected_flat = LiteralUtil::CreateR1({10}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto mirrored_weights = Rev(weights, {2, 3, 4}); ConvWithGeneralPadding(gradients, mirrored_weights, /*window_strides=*/{1, 1, 1}, /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { @@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); auto activations_literal = - activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); - auto activations = ConstantLiteral(&builder, *activations_literal); + activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); + auto activations = ConstantLiteral(&builder, activations_literal); auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto forward_conv = ConvGeneralDilated(activations, gradients, @@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder::CreateDefaultConvDimensionNumbers( /*num_spatial_dims=*/3)); Transpose(forward_conv, {0, 1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 5ef273e5a26ea8a16db864974c9bfa2c296cbce8..1407e68d9a336b6bb1c960711015430f872aa912 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -40,49 +40,48 @@ class CopyOpTest : public HloTestBase { protected: void TestCopyOp(const Literal& literal) { auto builder = HloComputation::Builder(TestName()); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + auto constant = + builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone())); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); + Literal result = ExecuteAndTransfer(std::move(module), {}); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4, - tensorflow::gtl::ArraySlice permutation); + absl::Span permutation); }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*LiteralUtil::CreateR0(true)); + TestCopyOp(LiteralUtil::CreateR0(true)); } XLA_TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*LiteralUtil::CreateR1({})); + TestCopyOp(LiteralUtil::CreateR1({})); } XLA_TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); + TestCopyOp(LiteralUtil::CreateR1({1, 2, 3})); } XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4( + TestCopyOp(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } XLA_TEST_F(CopyOpTest, CopyParameterScalar) { @@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { // Copy literal to device to use as parameter. auto literal = LiteralUtil::CreateR0(42.0); - Shape shape = literal->shape(); + Shape shape = literal.shape(); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); @@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(module), {literal.get()}); - LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {&literal}); + LiteralTestUtil::ExpectR0Near(42.0f, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { @@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, *result, + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. - Layout* literal_layout = - literal->mutable_shape_do_not_use()->mutable_layout(); + Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); literal_layout->mutable_minor_to_major()->SwapElements(0, 1); @@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); // The result of the computation has the default layout, which is the inverse // of the layout of the source literal. - LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, *result, + LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, result, error_spec_); } @@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); + Literal literal = LiteralUtil::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -182,14 +178,14 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR3EqualArray3D(a, *result); + LiteralTestUtil::ExpectR3EqualArray3D(a, result); } -void CopyOpTest::TestCopyConstantLayoutR4( - size_t n1, size_t n2, size_t n3, size_t n4, - tensorflow::gtl::ArraySlice permutation) { +void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, + size_t n4, + absl::Span permutation) { Array4D a(n1, n2, n3, n4); for (size_t i = 0; i < n1; ++i) { for (size_t j = 0; j < n2; ++j) { @@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); + Literal literal = LiteralUtil::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4( auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR4EqualArray4D(a, *result); + LiteralTestUtil::ExpectR4EqualArray4D(a, result); } XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) { @@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { XlaBuilder builder(TestName()); Parameter(&builder, 0, in_shape, "input"); - auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); + auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index d12a4e7fcd7813775a81677bcaa07af60ff9b477..410732c07b7b6d3ece33ab11f4778241dc53ca50 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal = LiteralUtil::CreateR1({1, 2, 3}); - EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); + EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal})); } XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { @@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ( - *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0, &literal1})); } // On the GPU backend, constants get special handling. Someone might pass a @@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 13c777835eb2d2519d39205cdc96f0aac4850c7d..a693fa35954bcb2d95074c94d0aa3eabc1d5fd62 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { module->AddEntryComputation(builder.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(44.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { @@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { module->AddEntryComputation(builder.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(10.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, @@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest, module->AddEntryComputation(b.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR3EqualArray3D( - Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result); + Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } class CustomCallClientAPITest : public ClientLibraryTestBase {}; diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 5f234f36a8543ad408fb3430b27844beb16a54b5..86fd1ceb1368feedb14088fa7045224440f6c4f9 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { @@ -36,7 +36,7 @@ class DeallocationTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. std::unique_ptr ExecuteAndCheckTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 2db6503afab748d7b778e26b2f9350ac64c7778b..e0f23b0fa807ca27038afa2eec5f739508e3d5bd 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -42,7 +42,7 @@ class DeconstructTupleTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. std::unique_ptr ExecuteAndCheckTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) @@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { // Try copying the elements back and comparing it auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { @@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { auto handles1 = result_status1.ConsumeValueOrDie(); auto handles2 = result_status2.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); handles1[0].reset(); handles1[1].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { @@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { // the same as handle[3] and handle[1] should be the same as handle[2]. auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { @@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { // should not have been deallocated because of reference counting. global_data.reset(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); /// Try deallocating one of the repeated elements, then copy handles[0].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { @@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); Tuple(&builder, {p}); auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()}); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0e9e92ed996fbb34826d19b670c7c4920a1aad13..0171f515839d556827f0723772214d175939d386 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -67,16 +68,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::CreateR2({{5, 6}, {7, 8}}).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::CreateR2({{5, 6}, {7, 8}})}), "arg0", &builder, ¶m); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, - *LiteralUtil::CreateR2({{19, 22}, {43, 50}}), + LiteralUtil::CreateR2({{19, 22}, {43, 50}}), {param_data.get()}); } @@ -195,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { auto lhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) .ConsumeValueOrDie(); @@ -218,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f}, {3.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -261,16 +262,14 @@ string PrintDotTestParam( const ::testing::TestParamInfo& test_param) { const DotTestParam& param = test_param.param; if (param.has_addend) { - return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, - "_MajorToMinor", - param.dot_lhs_row_major ? "T" : "F", - param.dot_rhs_row_major ? "T" : "F", - param.addend_row_major ? "T" : "F"); + return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F", + param.addend_row_major ? "T" : "F"); } else { - return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, - "_MajorToMinor", - param.dot_lhs_row_major ? "T" : "F", - param.dot_rhs_row_major ? "T" : "F"); + return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F"); } } @@ -287,24 +286,23 @@ void ParametricDotTest::TestImpl() { std::unique_ptr> dot_lhs_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); - std::unique_ptr dot_lhs_lit = - LiteralUtil::CreateR2FromArray2DWithLayout( - *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor( - param.dot_lhs_row_major))); + Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = - client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie(); std::unique_ptr> dot_rhs_data = MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); Layout rhs_layout = LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); - std::unique_ptr dot_rhs_lit = + Literal dot_rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = - client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie(); std::unique_ptr> addend_data; - std::unique_ptr addend_lit; + Literal addend_lit; std::unique_ptr addend_handle; if (param.has_addend) { @@ -312,7 +310,7 @@ void ParametricDotTest::TestImpl() { addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); - addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); + addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie(); } XlaBuilder builder(TestName()); @@ -478,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -512,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); @@ -585,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, @@ -593,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{11.0f, 22.0f}, {33.0f, 44.0f}}, {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) @@ -631,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) .ConsumeValueOrDie(); @@ -669,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{9.0f, 10.0f}, {11.0f, 12.0f}}, {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) @@ -677,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) .ConsumeValueOrDie(); @@ -709,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { auto lhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *lhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *rhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); @@ -779,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); this->template ComputeAndCompareR2( @@ -828,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); this->template ComputeAndCompareR2( diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 7f6f203a1ba48e0053f799c58bbbeae87aef1f7f..7501c6d957e7afe99b8c530e5f0d575f818367da 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -114,23 +114,23 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values_int, + void RunR1(absl::Span input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - tensorflow::gtl::ArraySlice expected_values_int) { + absl::Span expected_values_int) { // bfloat16 has explicit constructors, so it does not implicitly convert the // way built-in types do, which is why we can't take the parameter as an - // ArraySlice. We also can't convert it to a vector, because - // vector is special so that it cannot be an ArraySlice, which + // Span. We also can't convert it to a vector, because + // vector is special so that it cannot be a Span, which // is what the code below wants. So instead we do this. Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie(); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { void RunR0(int input_value_int, int update_value_int, const std::vector slice_starts, int expected_value_int) { Literal input_value = - std::move(*LiteralUtil::CreateR0(input_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(input_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_value = - std::move(*LiteralUtil::CreateR0(update_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(update_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_value = - std::move(*LiteralUtil::CreateR0(expected_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(expected_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -385,22 +385,22 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values_int, - tensorflow::gtl::ArraySlice update_values_int, + void RunR1(absl::Span input_values_int, + absl::Span update_values_int, const std::vector slice_starts, - tensorflow::gtl::ArraySlice expected_values_int) { + absl::Span expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR1(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { - std::unique_ptr literal = - LiteralUtil::CreateR3FromArray3D(values); - LOG(INFO) << name << ":" << literal->ToString(); + Literal literal = LiteralUtil::CreateR3FromArray3D(values); + LOG(INFO) << name << ":" << literal.ToString(); } }; @@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) { auto input_literal = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - auto input = ConstantLiteral(&builder, *input_literal); + auto input = ConstantLiteral(&builder, input_literal); // Create dynamic slice start indices as a parameter: shape [4] auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); @@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) { auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - stream.get(), *start_indices_literal, buffer)); + stream.get(), start_indices_literal, buffer)); std::unique_ptr executable = client diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index 5116e60ca63ef5f94b25b15e6616086fb9e44bbb..b08ece0e63e9472f657b49b533511e9b192d3212 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr input, client_->TransferToServer( - *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index bf1de02ba9dbd97db9ee31484402fe9b92385219..51b50d456e496c9c01c38fb8539bb3737de16937 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -38,29 +38,29 @@ class ExhaustiveF32ElementwiseOpTest XlaBuilder builder(TestName()); - std::unique_ptr input_literal = + Literal input_literal = LiteralUtil::CreateFromDimensions(F32, {input_size}); for (int64 i = begin; i < end; i++) { if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { // If the operation is known to be buggy on a specific input clamp that // input to 0 under the assumption that the op is at least correct on 0. - input_literal->Set({i - begin}, 0.0f); + input_literal.Set({i - begin}, 0.0f); } else { - input_literal->Set({i - begin}, tensorflow::bit_cast(i)); + input_literal.Set({i - begin}, tensorflow::bit_cast(i)); } } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); enqueue_op(&builder, input); std::vector expected_result; expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(evaluate_op(input_literal->Get({i}))); + expected_result.push_back(evaluate_op(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 39cc6c5927f1d416e31f689487efc10c20371abe..3be9657db40a7ea073baca32d8a20ccd6fa8a274 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -37,10 +37,9 @@ class FloorCeilTest : public ClientLibraryTestBase { }; // Runs a computation and comparison on expected vs f(input) - void TestR1F32(tensorflow::gtl::ArraySlice input, - tensorflow::gtl::ArraySlice expected, Function f) { - LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") - << "}"; + void TestR1F32(absl::Span input, + absl::Span expected, Function f) { + LOG(INFO) << "input: {" << absl::StrJoin(expected, ", ") << "}"; XlaBuilder builder(TestName()); auto c = ConstantR1(&builder, input); if (f == kCeil) { diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 792be0d3fcd55621b9f8cdf0fdc28f7bb49294d1..9c94acb437e9fc948a4255f7112e2e7a40cfa5fb 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -22,13 +22,14 @@ limitations under the License. #define EIGEN_USE_THREADS +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -42,14 +43,11 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" -using tensorflow::gtl::ArraySlice; - namespace xla { namespace { @@ -113,26 +111,26 @@ class FusionTest : public HloTestBase { hlos[0] = builder.AddInstruction(std::move(root_hlo)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction( - ArraySlice(hlos, 0, Arity + 1), + absl::Span(hlos).subspan(0, Arity + 1), HloInstruction::FusionKind::kLoop); auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4))); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } } private: template - T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice xs); + T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span xs); }; template <> float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - ArraySlice xs) { + absl::Span xs) { switch (opcode) { case HloOpcode::kAdd: return xs[0] + xs[1]; @@ -157,7 +155,7 @@ float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, template <> bool FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - ArraySlice xs) { + absl::Span xs) { switch (opcode) { case HloOpcode::kEq: return xs[0] == xs[1]; @@ -224,8 +222,8 @@ XLA_TEST_F(FusionTest, Test) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.5}, {2.72}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.5}, {2.72}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. @@ -250,8 +248,8 @@ XLA_TEST_F(FusionTest, Parameter) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -285,7 +283,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { // Every element of result should be y = x^2 = 4.0. for (int i = 0; i < rand_dim0_size; ++i) { for (int j = 0; j < dim1_size; ++j) { - EXPECT_EQ(4.0, result->Get({i, j})); + EXPECT_EQ(4.0, result.Get({i, j})); } } } @@ -310,8 +308,8 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -325,8 +323,8 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(5), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -340,8 +338,8 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -355,8 +353,8 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -370,8 +368,8 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -385,8 +383,8 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{7}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -400,8 +398,8 @@ XLA_TEST_F(FusionTest, Reshape__) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -415,8 +413,8 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -430,8 +428,8 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -445,8 +443,8 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -461,8 +459,8 @@ XLA_TEST_F(FusionTest, Reverse) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({3, 2, 1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) { @@ -479,8 +477,8 @@ XLA_TEST_F(FusionTest, ReverseNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-3, -2, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-3, -2, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -497,8 +495,8 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -515,8 +513,8 @@ XLA_TEST_F(FusionTest, SliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -3}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -537,8 +535,8 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-2, -3}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -554,9 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, TransposeNegate) { @@ -572,9 +570,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -601,11 +599,11 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, - HloInstruction::FusionKind::kLoop); + HloInstruction::FusionKind::kInput); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -626,8 +624,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(-15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -676,8 +674,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -712,8 +710,8 @@ XLA_TEST_F(FusionTest, SharedConstant) { EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({8}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({8}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } @@ -784,19 +782,17 @@ ENTRY main { } )"; - std::unique_ptr operand = - LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); + Literal operand = LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(hlo_text, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, - test_runner_.Execute(std::move(module), {operand.get()}, - /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + test_runner_.Execute(std::move(module), {&operand}, + /*run_hlo_passes=*/false)); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), - *result)); + LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + result)); } class FusionClientLibraryTest : public ClientLibraryTestBase {}; @@ -823,16 +819,16 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { // where overflow is OK. Array2D arr(32, 32); arr.FillUnique(); - std::unique_ptr l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({0, 1})); - std::unique_ptr l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({1, 0})); - XlaOp p0 = AddParam(*l1, &b); + XlaOp p0 = AddParam(l1, &b); XlaOp sum = p0; for (int i = 1; i < kNumParams; ++i) { - auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b); + auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b); sum = sum + p0 * pN * pN; } @@ -881,19 +877,19 @@ void BM_ParallelFusion(int num_iters) { auto param0_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); ScopedShapedBuffer buffer0 = - client->LiteralToShapedBuffer(*param0_literal, device_ordinal) + client->LiteralToShapedBuffer(param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); ScopedShapedBuffer buffer1 = - client->LiteralToShapedBuffer(*param1_literal, device_ordinal) + client->LiteralToShapedBuffer(param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); ScopedShapedBuffer buffer2 = - client->LiteralToShapedBuffer(*param2_literal, device_ordinal) + client->LiteralToShapedBuffer(param2_literal, device_ordinal) .ConsumeValueOrDie(); // Build executable. diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index b77bece85ad1b2192b04330af9e60d3a424b59f4..daa89398a697af9149797d621c3bdca80a00aedd 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -25,17 +25,16 @@ limitations under the License. namespace xla { namespace { -using tensorflow::gtl::nullopt; +using absl::nullopt; class GatherOperationTest : public HloTestBase { protected: void RunTest(const string& hlo_text, Literal* operand, - Literal* gather_indices) { - RunTest(hlo_text, {operand, gather_indices}); + Literal* start_indices) { + RunTest(hlo_text, {operand, start_indices}); } - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -52,18 +51,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { @@ -74,18 +72,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { @@ -96,18 +93,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { @@ -118,18 +114,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { @@ -140,18 +136,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) ROOT gather = s32[2,1,1,2] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { @@ -162,20 +158,19 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { @@ -186,20 +181,19 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, DynamicSlice) { @@ -210,18 +204,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { @@ -232,18 +225,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { @@ -254,17 +246,16 @@ ENTRY main { operand = s32[3,0] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,0] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 0} + slice_sizes={1, 0} } )"; - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { @@ -278,19 +269,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { @@ -304,19 +295,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = u32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { @@ -330,19 +321,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { @@ -356,19 +347,19 @@ ENTRY main { operand = u32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = u32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = u32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -379,17 +370,17 @@ ENTRY main { operand = s32[2,3,2]{2,1,0} parameter(0) index = s32[] parameter(1) ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index), - output_window_dims={0,1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0}, + offset_dims={0,1,2}, + collapsed_slice_dims={}, + start_index_map={0}, index_vector_dim=0, - window_bounds={1,3,2} + slice_sizes={1,3,2} } )"; - std::unique_ptr operand = LiteralUtil::CreateR3( + Literal operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ScalarResult) { @@ -400,16 +391,16 @@ ENTRY main { operand = s32[4]{0} parameter(0) index = s32[] parameter(1) ROOT gather = s32[] gather(operand, index), - output_window_dims={}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=0, - window_bounds={1} + slice_sizes={1} } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { @@ -420,17 +411,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[0] parameter(1) ROOT gather = s32[0,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { @@ -441,21 +432,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[3,2] broadcast(one), dimensions={} ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { @@ -466,21 +456,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { @@ -491,21 +480,21 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { @@ -516,23 +505,22 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, @@ -544,23 +532,22 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { @@ -571,21 +558,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[1,1] broadcast(one), dimensions={} ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { @@ -596,21 +582,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } class GatherClientLibraryTest : public ClientLibraryTestBase {}; @@ -622,11 +607,11 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { // operand = s32[3,3] parameter(0) // indices = s32[2] parameter(1) // ROOT gather = s32[2,3] gather(operand, indices), - // output_window_dims={1}, - // elided_window_dims={0}, - // gather_dims_to_operand_dims={0}, + // offset_dims={1}, + // collapsed_slice_dims={0}, + // start_index_map={0}, // index_vector_dim=1, - // window_bounds={1, 3} + // slice_sizes={1, 3} // } XlaBuilder builder("gather_basic"); @@ -637,9 +622,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { auto operand = Parameter(&builder, 0, operand_shape, "operand"); auto indices = Parameter(&builder, 1, indices_shape, "indices"); GatherDimensionNumbers dim_numbers; - dim_numbers.add_output_window_dims(1); - dim_numbers.add_elided_window_dims(0); - dim_numbers.add_gather_dims_to_operand_dims(0); + dim_numbers.add_offset_dims(1); + dim_numbers.add_collapsed_slice_dims(0); + dim_numbers.add_start_index_map(0); dim_numbers.set_index_vector_dim(1); Gather(operand, indices, dim_numbers, {1, 3}); @@ -647,10 +632,10 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr operand_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr indices_arg, - client_->TransferToServer(*LiteralUtil::CreateR1({0, 2}))); + client_->TransferToServer(LiteralUtil::CreateR1({0, 2}))); TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); @@ -664,10 +649,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::vector> result_data, client_->ExecuteParallel(computation_instances)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, client_->Transfer(*(result_data[0]))); - LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, - *result_literal); + LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, result_literal); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 51450314b611b49c643fb6fd5b0c0d2e7205a2d2..1115e50fe3120b7dbd891f07dedcacefa5ecf3ea 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -126,9 +126,8 @@ INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ [](half x) { return isfinite(x); }, &IsFinite})); -using BinaryBuildFuncTy = - std::function)>; +using BinaryBuildFuncTy = std::function)>; struct BinaryOpTestParam { std::function compute_func; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index b662e837168c8b16daea0181786be19fa0237a8c..bdd4fd7e3d0f585d81e94a3326e6d24bb5c42f39 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -20,17 +20,20 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -39,9 +42,8 @@ namespace xla { namespace { -using tensorflow::StringPiece; -using tensorflow::gtl::ArraySlice; -using tensorflow::gtl::optional; +using absl::optional; +using absl::string_view; constexpr char kInterpreter[] = "interpreter"; @@ -83,21 +85,50 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace -HloTestBase::HloTestBase() - : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} +HloTestBase::HloTestBase(bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloTestBase(GetTestPlatform(), GetReferencePlatform(), + verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} HloTestBase::HloTestBase(se::Platform* test_platform, - se::Platform* reference_platform) + se::Platform* reference_platform, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = MakeUnique(/*allow_mixed_precision=*/true); + hlo_verifier_ = absl::make_unique( + /*layout_sensitive=*/verifier_layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); } -/* static */ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { - return MakeUnique(name, GetModuleConfigForTest()); + return absl::make_unique(name, GetModuleConfigForTest()); +} + +/* static */ +StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, + HloModule* module) { + const string module_str_before_run = module->ToProto().ShortDebugString(); + const auto status_or = hlo_pass->Run(module); + if (status_or.status().ok()) { + const string module_str_after_run = module->ToProto().ShortDebugString(); + if (!status_or.ValueOrDie()) { + // Check that the proto remains same. + EXPECT_EQ(module_str_after_run, module_str_before_run); + } + } + return status_or; } -/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { +/* static */ +PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfig::DEFAULT); + return precision_config; +} + +DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); @@ -105,24 +136,21 @@ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { return debug_options; } -StatusOr> HloTestBase::Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { +StatusOr HloTestBase::Execute(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments); } -std::unique_ptr HloTestBase::ExecuteNoHloPasses( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { +Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments) { return test_runner_ .Execute(std::move(module), arguments, /*run_hlo_passes=*/false) .ValueOrDie(); } -std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { +Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } @@ -145,7 +173,8 @@ StatusOr> HloTestBase::MakeReferenceModule( } StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, bool run_hlo_passes, const std::function& reference_preprocessor) { TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status()); @@ -159,12 +188,13 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( TF_ASSIGN_OR_RETURN(auto reference, reference_runner_.Execute(std::move(reference_module), arguments, run_hlo_passes)); - return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, error); } ::testing::AssertionResult HloTestBase::RunAndCompare( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -177,7 +207,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -192,13 +223,12 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( ::testing::AssertionResult HloTestBase::RunAndCompare( std::unique_ptr module, const optional& error, const std::function& reference_preprocessor) { - const auto& fake_arguments = - MakeFakeArguments(module.get()).ConsumeValueOrDie(); + auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompare(std::move(module), fake_argument_ptrs, error, reference_preprocessor); @@ -210,17 +240,16 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, reference_preprocessor); } ::testing::AssertionResult HloTestBase::RunAndCompare( - const StringPiece hlo_string, - const tensorflow::gtl::optional& error, + string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); @@ -233,8 +262,31 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } +::testing::AssertionResult HloTestBase::Run(string_view hlo_string) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + const auto& fake_arguments = + MakeFakeArguments(module_or_status.ValueOrDie().get()) + .ConsumeValueOrDie(); + std::vector fake_argument_ptrs; + absl::c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + return test_runner_ + .Execute(std::move(module_or_status.ValueOrDie()), + fake_argument_ptrs, /*run_hlo_passes=*/true) + .ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure(); +} + ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( - const string& filename, const tensorflow::gtl::optional& error, + const string& filename, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); @@ -247,8 +299,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - const StringPiece hlo_string, - const tensorflow::gtl::optional& error, + string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); @@ -262,7 +313,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( - const string& filename, const tensorflow::gtl::optional& error, + const string& filename, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); @@ -275,10 +326,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } HloComputation* HloTestBase::FindComputation(HloModule* module, - tensorflow::StringPiece name) { + absl::string_view name) { auto computations = module->computations(); - auto it = c_find_if(computations, - [&](HloComputation* c) { return c->name() == name; }); + auto it = absl::c_find_if( + computations, [&](HloComputation* c) { return c->name() == name; }); if (it == computations.end()) { return nullptr; } @@ -286,11 +337,11 @@ HloComputation* HloTestBase::FindComputation(HloModule* module, } HloInstruction* HloTestBase::FindInstruction(HloModule* module, - tensorflow::StringPiece name) { + absl::string_view name) { for (const HloComputation* c : module->computations()) { auto instructions = c->instructions(); - auto it = c_find_if(instructions, - [&](HloInstruction* i) { return i->name() == name; }); + auto it = absl::c_find_if( + instructions, [&](HloInstruction* i) { return i->name() == name; }); if (it != instructions.end()) { return *it; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 66719b1460063a61541535ff7507468ae0ca1ada..0ae4bdc104d656946d45008adec9ea3960984545 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -31,8 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" @@ -72,49 +72,59 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. - static std::unique_ptr CreateNewModule( - const string& name = TestName()); + std::unique_ptr CreateNewModule(const string& name = TestName()); + + // Runs the hlo_pass with the provided module and returns the result. This + // function also verifies that the module remains unchanged when hlo_pass + // returns false as the StatusOr value. + static StatusOr RunHloPass(HloPassInterface* hlo_pass, + HloModule* module); + + static PrecisionConfig DefaultPrecisionConfig(int operands); protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the // interpreter is the only supported backend, it will be both the test backend // and the reference backend. - HloTestBase(); + HloTestBase(bool verifier_layout_sensitive = false, + bool allow_mixed_precision_in_hlo_verifier = true); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. - HloTestBase(se::Platform* test_platform, se::Platform* reference_platform); + HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool verifier_layout_sensitive = false, + bool allow_mixed_precision_in_hlo_verifier = true); ~HloTestBase() override {} // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. - static DebugOptions GetDebugOptionsForTest(); + // + // This function is virtual so tests can specify an alternative set of debug + // options (e.g. disabling additional passes). + virtual DebugOptions GetDebugOptionsForTest(); // Gets an HloModuleConfig with options appropriate for tests. - static HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig GetModuleConfigForTest() { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); return config; } // Executes the given module and return the result as a Literal. - StatusOr> Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + StatusOr Execute(std::unique_ptr module, + absl::Span arguments); // Same as above, except the module will be executed without running any HLO // passes on it. - std::unique_ptr ExecuteNoHloPasses( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + Literal ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments); - std::unique_ptr ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + Literal ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments); // Executes the given hlo module on two backends and compares results. // @@ -129,8 +139,8 @@ class HloTestBase : public ::testing::Test { // modified. ::testing::AssertionResult RunAndCompare( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - const tensorflow::gtl::optional& error, + const absl::Span arguments, + const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -138,23 +148,21 @@ class HloTestBase : public ::testing::Test { // optimization. ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - const tensorflow::gtl::optional& error, + const absl::Span arguments, + const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; // Executes an hlo module with fake inputs and compares the results. ::testing::AssertionResult RunAndCompare( - std::unique_ptr module, - const tensorflow::gtl::optional& error, + std::unique_ptr module, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; // Same as above, except that the module will be executed without Hlo // optimization. ::testing::AssertionResult RunAndCompareNoHloPasses( - std::unique_ptr module, - const tensorflow::gtl::optional& error, + std::unique_ptr module, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -162,21 +170,23 @@ class HloTestBase : public ::testing::Test { // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. ::testing::AssertionResult RunAndCompare( - const tensorflow::StringPiece hlo_string, - const tensorflow::gtl::optional& error, + const absl::string_view hlo_string, + const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; + ::testing::AssertionResult Run(const absl::string_view hlo_string) + TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( - const string& filename, const tensorflow::gtl::optional& error, + const string& filename, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareNoHloPasses( - const tensorflow::StringPiece hlo_string, - const tensorflow::gtl::optional& error, + const absl::string_view hlo_string, + const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( - const string& filename, const tensorflow::gtl::optional& error, + const string& filename, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -219,10 +229,8 @@ class HloTestBase : public ::testing::Test { // // This is useful for tests which create HLOs from a string and then want to // inspect a particular computation or instruction. - HloComputation* FindComputation(HloModule* module, - tensorflow::StringPiece name); - HloInstruction* FindInstruction(HloModule* module, - tensorflow::StringPiece name); + HloComputation* FindComputation(HloModule* module, absl::string_view name); + HloInstruction* FindInstruction(HloModule* module, absl::string_view name); // Return an HLO verifier constructed for the test backend. HloVerifier& verifier() const { return *hlo_verifier_; } @@ -252,8 +260,8 @@ class HloTestBase : public ::testing::Test { // error happens before the results are computed, returns the error status. StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - const tensorflow::gtl::optional& error, bool run_hlo_passes, + const absl::Span arguments, + const absl::optional& error, bool run_hlo_passes, const std::function& reference_preprocessor); }; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index ad1f5b9eed8b5b140100c1fa35dc7d698e3db48b..8bd0a729b77f3ec14204952cb0062103c823883e 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -15,67 +15,74 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" namespace xla { -HloVerifiedTestBase::HloVerifiedTestBase() - : shape_verifier_(MakeUnique()) {} - -HloVerifiedTestBase::~HloVerifiedTestBase() { - // We can't call the ASSERT or EXPECT test macros in destructors, so we - // perform HLO verification in TearDown, and use the CHECK here to ensure - // users don't accidentally override the verification. - CHECK(tear_down_called_) - << "TearDown was never called; subclasses of HloVerifiedTestBase that " - << "override TearDown must call the superclass TearDown."; -} - -void HloVerifiedTestBase::TearDown() { - EXPECT_FALSE(tear_down_called_) - << "TearDown called more than once; it should be called exactly once."; - tear_down_called_ = true; - if (module_) { - VerifyModule(module_.get()); +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. + return Status::OK(); } - for (int i = 0; i < modules_.size(); ++i) { - VerifyModule(modules_.at(i).get()); - } - HloTestBase::TearDown(); + return verifier_.Run(this).status(); } -void HloVerifiedTestBase::VerifyModule(HloModule* module) { - HloVerifier verifier(/*allow_mixed_precision=*/true); - xla::StatusOr mutated = verifier.Run(module); - if (!mutated.ok()) { - ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); - } else { - EXPECT_FALSE(mutated.ValueOrDie()) - << "HloVerifier should never mutate the HloModule"; +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? "" : absl::StrCat(" (", message, ")")) + << ": " << status; } } +HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision) + : HloTestBase( + /*verifier_layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision), + verifier_layout_sensitive_(layout_sensitive), + allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {} + HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = HloTestBase::CreateNewModule(); + module_ = CreateNewVerifiedModule(TestName()); } return *module_; } HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { - modules_.emplace_back(HloTestBase::CreateNewModule()); + modules_.emplace_back(CreateNewVerifiedModule(name)); return modules_.back().get(); } -void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text, +void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); - VerifyModule(module_.get()); + module_ = CreateNewVerifiedModule(TestName()); + TF_CHECK_OK(ParseHloString(hlo_text, module_.get())); + module_->VerifyOrAddFailure("after parsing"); } + +StatusOr> +HloVerifiedTestBase::ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config) { + auto module = CreateNewVerifiedModule(TestName()); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} + +std::unique_ptr HloVerifiedTestBase::CreateNewVerifiedModule( + const string& name) { + return absl::make_unique( + name, GetModuleConfigForTest(), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 5b28c01c369fa1ae1c7941f5c8139882c4dbed08..388a99bb36408665edbc20ade6c6a733d64db88d 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -20,56 +20,84 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { -// A base class for HLO tests that stores a default HloModule, and automatically -// performs verification on that module on tear-down. +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. + void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + +// A base class for HLO tests that stores a default VerifiedHloModule. class HloVerifiedTestBase : public HloTestBase { protected: - HloVerifiedTestBase(); - ~HloVerifiedTestBase() override; + HloVerifiedTestBase(bool layout_sensitive = false, + bool allow_mixed_precision = false); // Constructs a default shape verifier. std::unique_ptr MakeShapeVerifier(); - // Performs verification on the default HloModule returned by module(). - // Automatically called by the testing framework for each test. - // - // REQUIRED: subclasses that override TearDown() must call this explicitly. - void TearDown() override; - // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). + ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") HloModule& module(); - void ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + + ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.") + void ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config = HloModuleConfig()); - // Sets the shape-size function used during hlo verification. If this isn't - // called, a default ShapeVerifier is used instead. - void SetShapeVerifier(std::unique_ptr shape_verifier) { - shape_verifier_ = std::move(shape_verifier); - } + // Parses the given string and returns module as a VerifiedHloModule. + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); // Creates a new module for a test, and stores it in modules_ so it can be // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent // creation of unverified modules. + ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") HloModule* CreateNewModule(const string& name = TestName()); + // Creates and returns a verified HLO module with the given name. + std::unique_ptr CreateNewVerifiedModule( + const string& name = TestName()); + + private: // It is confusing to store modules created by module() and CreateNewModule() // in different fields, but it allows us to migrate tests to // HloVerifiedTestBase more easily, so it's a win because we can verify more // modules. See b/80488902. - private: + // // Lazily populated. Access via module(). - std::unique_ptr module_; + std::unique_ptr module_; + // Populated by calls to CreateNewModule. - std::vector> modules_; - std::unique_ptr shape_verifier_; - bool tear_down_called_ = false; - static void VerifyModule(HloModule* module); + std::vector> modules_; + + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5c0263e811f94c90a69a460525ffa0c65127ebb5 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// This class includes unit tests which are expected to fail because invalid HLO +// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to +// include the necessary gunit parts to test this test machinery (needs the +// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the +// disabled tests enabled and failures can be manually compared against +// expectations. +class HloVerifiedTestBaseTest : public HloVerifiedTestBase {}; + +XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) { + // Test shouldn't fail if no module is created at all. +} + +XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) { + // Use module() to lazily create an empty module, build it up, and verify no + // failures. + HloModule& hlo_module = module(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + hlo_module.AddEntryComputation(builder.Build()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) { + // Use module() to lazily create an empty module and build up an invalid + // module. + HloModule& hlo_module = module(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + hlo_module.AddEntryComputation(builder.Build()); + + *hlo_module.entry_computation()->root_instruction()->mutable_shape() = + ShapeUtil::MakeShape(PRED, {1, 2, 3}); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) { + // Call CreateNewModule and build up a valid module. + HloModule* module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + module->AddEntryComputation(builder.Build()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) { + // Call CreateNewModule and build up a invalid module. + HloModule* module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + module->AddEntryComputation(builder.Build()); + + *module->entry_computation()->root_instruction()->mutable_shape() = + ShapeUtil::MakeShape(PRED, {1, 2, 3}); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) { + const char* const hlo_string = R"( +HloModule ParseAndVerifyModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} +)"; + + ParseAndVerifyModule(hlo_string); + EXPECT_EQ(module().entry_computation()->instruction_count(), 3); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_EQ(module->entry_computation()->instruction_count(), 3); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} + +RANDOM GARBAGE +)"; + + ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleBad + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[1234] add(x,y) +} +)"; + + ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 17ac95ae0198d98490b25f7f2edd32d1e0495803..310f3495922250d68aa463fcbb24ef0b04603d09 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -23,40 +23,95 @@ limitations under the License. namespace xla { namespace { -class IotaTest : public ClientLibraryTestBase { - public: - explicit IotaTest(se::Platform* platform = nullptr) - : ClientLibraryTestBase(platform) {} - template - std::vector GetExpected(const int64 num_elements) { - std::vector result(num_elements); - std::iota(result.begin(), result.end(), 0); - return result; +template +std::vector GetR1Expected(const int64 num_elements) { + std::vector result(num_elements); + std::iota(result.begin(), result.end(), 0); + return result; +} + +class IotaR1Test + : public ClientLibraryTestBase, + public ::testing::WithParamInterface> {}; + +TEST_P(IotaR1Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + Iota(&builder, element_type, num_elements); + if (element_type == F32) { + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), {}, + ErrorSpec{0.0001}); + } else if (element_type == U32) { + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), + {}); + } else { + CHECK_EQ(element_type, S32); + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), + {}); } -}; - -XLA_TEST_F(IotaTest, SimpleR1) { - for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) { - { - XlaBuilder builder(TestName() + "_f32"); - IotaGen(&builder, F32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), {}, - ErrorSpec{0.0001}); - } - { - XlaBuilder builder(TestName() + "_u32"); - IotaGen(&builder, U32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), - {}); - } - { - XlaBuilder builder(TestName() + "_s32"); - IotaGen(&builder, S32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), - {}); - } +} + +INSTANTIATE_TEST_CASE_P(IotaR1TestInstantiation, IotaR1Test, + ::testing::Combine(::testing::Values(F32, U32, S32), + ::testing::Range(/*start=*/10, + /*end=*/10001, + /*step=*/10))); + +class IotaR2Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(IotaR2Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); } } +INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1))); + +class IotaR3Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(IotaR3Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42, 19}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); + } +} + +INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1, 2))); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index cde1dcd9cd10c86107f495a92be42b57bf6a085b..554eb24d44168caa7d7252015e3d99f2d567df9b 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -35,8 +35,7 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { int64 now_usec = tensorflow::Env::Default()->NowMicros(); string filename = tensorflow::io::JoinPath( tensorflow::testing::TmpDir(), - tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(), - now_usec, name.c_str())); + absl::StrFormat("tempfile-%s-%x-%s", get_hostname(), now_usec, name)); TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, literal.ToProto())); LOG(ERROR) << "wrote to " << name << " file: " << filename; @@ -94,7 +93,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( const LiteralSlice& expected, const LiteralSlice& actual, - const tensorflow::gtl::optional& error) { + const absl::optional& error) { if (error.has_value()) { VLOG(1) << "Expects near"; return StatusToAssertion(literal_comparison::Near( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 31a099c15f1f20457c90de97054f68a31eb49011..43cca91f64b2c0fbfde5054a361cf0f95302c23d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -32,8 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -62,7 +62,7 @@ class LiteralTestUtil { static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); template - static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, + static void ExpectR1Equal(absl::Span expected, const LiteralSlice& actual); template static void ExpectR2Equal( @@ -102,7 +102,7 @@ class LiteralTestUtil { const ErrorSpec& error); template - static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, + static void ExpectR1Near(absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error); template @@ -146,7 +146,7 @@ class LiteralTestUtil { // will be compared recursively. static ::testing::AssertionResult NearOrEqual( const LiteralSlice& expected, const LiteralSlice& actual, - const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; + const absl::optional& error) TF_MUST_USE_RESULT; private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); @@ -155,20 +155,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR0(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); + absl::Span expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(LiteralUtil::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2(expected), actual)); } template @@ -176,46 +176,46 @@ template std::initializer_list>> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR0(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( - tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, + absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2(expected), actual, error)); } template @@ -223,7 +223,7 @@ template std::initializer_list>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3(expected), actual, error)); } template @@ -232,28 +232,28 @@ template std::initializer_list>>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index f297b2b847f570d26e71ddcd8e34bc626f982e1f..b6f9b8156b51144e4f74d285b1e4111d098f13c2 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -31,11 +31,11 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal literal = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); } TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -43,15 +43,15 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. auto unequal_things_are_equal = [] { - std::unique_ptr lhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal lhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - std::unique_ptr rhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(64).get(), - LiteralUtil::CreateR0(42).get(), + Literal rhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(64), + LiteralUtil::CreateR0(42), }); - CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; + CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal"; }; ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal"); } @@ -61,7 +61,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto two = LiteralUtil::CreateR0(2); auto four = LiteralUtil::CreateR0(4); ErrorSpec error(0.001); - CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; + CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four"; }; tensorflow::Env* env = tensorflow::Env::Default(); @@ -80,20 +80,20 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { std::vector results; TF_CHECK_OK(env->GetMatchingPaths(pattern, &results)); - LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]"; + LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]"; EXPECT_EQ(3, results.size()); for (const string& result : results) { LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, &literal_proto)); - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal->ToString()); + EXPECT_EQ("2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal->ToString()); + EXPECT_EQ("4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { - EXPECT_EQ("true", literal->ToString()); + EXPECT_EQ("true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -103,10 +103,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto expected = LiteralUtil::CreateR1({1, 2, 3}); auto actual = LiteralUtil::CreateR1({4, 5, 6}); - ::testing::AssertionResult result = - LiteralTestUtil::Equal(*expected, *actual); - EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); - EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}")); + ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); + EXPECT_THAT(result.message(), + ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); + EXPECT_THAT(result.message(), + ::testing::HasSubstr("Actual literal:\n{4, 5, 6}")); } TEST(LiteralTestUtilTest, NearComparatorR1) { @@ -114,7 +115,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtilTest, NearComparatorR1Nan) { @@ -122,7 +123,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1Nan) { {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtil, NearComparatorDifferentLengths) { @@ -130,8 +131,8 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); - EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); - EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index e719da54d45d3e6eb3f3e14d3fa3076db2081e04..8d658695576035cdc34a213847460dd80de5f67e 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" @@ -125,7 +126,7 @@ class LLVMCompilerTest : public ::testing::Test { static std::unique_ptr CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return MakeUnique(TestName(), config); + return absl::make_unique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 6fc11150978931f980349799372872f9fb68f292..0487d314094edcab61a92de32f14113dd19673fa 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -51,8 +51,9 @@ void LlvmIrGenTestBase::CompileAndVerifyIr( std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status()); + Status status = CompileToExecutable(std::move(hlo_module)).status(); ResetIrHook(); + TF_ASSERT_OK(status); StatusOr filecheck_result = RunFileCheck(ir_, pattern); TF_ASSERT_OK(filecheck_result.status()); @@ -73,9 +74,10 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr( std::unique_ptr hlo_module, const AotCompilationOptions& options, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - TF_ASSERT_OK( - CompileToAotCompilationResult(std::move(hlo_module), options).status()); + Status status = + CompileToAotCompilationResult(std::move(hlo_module), options).status(); ResetIrHook(); + TF_ASSERT_OK(status); StatusOr filecheck_result = RunFileCheck(ir_, pattern); ASSERT_TRUE(filecheck_result.ok()); diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index e2cd5bcc5a95f692dcf4a43d717252bfe876aa81..dbdd20daf0c3a54ed7b6e2a9d3fb73274d77474a 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); int64 allocation_count_before = allocator_->allocation_count(); @@ -53,12 +53,12 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { // deallocation happen on the right allocator. ExecutableRunOptions options; options.set_allocator(allocator); - tensorflow::gtl::optional result = + absl::optional result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), options); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_); // At least one allocation should have been performed when executing the // computation. @@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { computation, {}, ExecutableBuildOptions().set_device_ordinal(d), ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); // At least one allocation should have been performed when executing the // computation. diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index e310966d8b062f2baac00a17dd42cd449595d0d2..60eb21aafd23a8d724d1f08d5c87098b7c3dcd6b 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -92,10 +92,10 @@ int main(int argc, char** argv) { // It's lame to hard-code the buffer assignments, but we need // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 1); - CHECK_EQ(result->buffer_sizes().size(), 3); - CHECK_EQ(result->buffer_sizes()[0], -2); // param buffer - CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer - CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer + CHECK_EQ(result->buffer_infos().size(), 3); + CHECK(result->buffer_infos()[0].is_entry_parameter()); // param buffer + CHECK_EQ(result->buffer_infos()[1].size(), sizeof(float)); // result buffer + CHECK(result->buffer_infos()[2].is_constant()); // const buffer if (triple.isOSBinFormatELF()) { // Check the ELF magic. CHECK_EQ(result->object_file_data()[0], 0x7F); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 1a823cf189b310c62c735419936544ea99fcfbaf..a99b43f4690b3063f76e2cda1e58c9b4ba9a1df4 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - LiteralTestUtil::ExpectR0Near(123.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(123.f, ShapedBufferToLiteral(result), error_spec_); } @@ -68,10 +68,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { auto y = ConstantR0(&builder, 123.0f); Add(x, y); - auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0(42.0f)); + auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0(42.0f)); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); - LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(165.f, ShapedBufferToLiteral(result), error_spec_); } @@ -81,10 +81,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { auto y = ConstantR1(&builder, {}); Add(x, y); - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1({})); + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1({})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); - LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR1Near({}, ShapedBufferToLiteral(result), error_spec_); } @@ -95,11 +95,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { @@ -109,14 +109,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; ScopedShapedBuffer result = ExecuteLocallyOrDie( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions().set_execution_profile(&profile)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); EXPECT_GT(profile.compute_and_transfer_time_ns(), 0); } @@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. - auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -142,15 +142,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with the parameter values in a different order. ScopedShapedBuffer result_param_swap = ExecuteLocallyOrDie(computation, {&y_array, &x_array}); - LiteralTestUtil::ExpectR2Near( - {{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_param_swap), error_spec_); + LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, + ShapedBufferToLiteral(result_param_swap), + error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { @@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( @@ -174,7 +174,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with row-major result layout. @@ -186,7 +186,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_rowmaj), + ShapedBufferToLiteral(result_rowmaj), error_spec_); } @@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -208,13 +208,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {2})); + LiteralSlice(result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -236,15 +236,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 0})); + LiteralSlice(result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {0, 1})); + LiteralSlice(result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 2})); + LiteralSlice(result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); ExecutableBuildOptions options = DefaultExecutableBuildOptions(); Shape shape_with_layout = ShapeUtil::MakeTupleShape( @@ -268,11 +268,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array}, options, DefaultExecutableRunOptions()); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -298,15 +298,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}); - auto y_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({2.0, 4.0, 6.0}).get(), - LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); + auto x_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}); + auto y_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({2.0, 4.0, 6.0}), + LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}})}); - auto x_buffer = LiteralToShapedBuffer(*x_literal); - auto y_buffer = LiteralToShapedBuffer(*y_literal); + auto x_buffer = LiteralToShapedBuffer(x_literal); + auto y_buffer = LiteralToShapedBuffer(y_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); @@ -314,11 +314,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -344,21 +344,20 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}) - .get(), - LiteralUtil::CreateR1({222.0, -2.0, 10.0}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}), + LiteralUtil::CreateR1({222.0, -2.0, 10.0})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -377,24 +376,24 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); + Literal result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralSlice(*result_0_literal, {0})); + LiteralSlice(result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, - LiteralSlice(*result_0_literal, {1})); + LiteralSlice(result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); - std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); + Literal result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, - LiteralSlice(*result_1_literal, {0})); + LiteralSlice(result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, - LiteralSlice(*result_1_literal, {1})); + LiteralSlice(result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -427,20 +426,19 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { // Feed in a tuple where each two-element vector element is {tuple_index, // -tuple_index}. - std::vector> arg_elements; + std::vector arg_elements; for (int i = 0; i < kElementCount; ++i) { arg_elements.push_back(LiteralUtil::CreateR1({1.0f * i, -1.0f * i})); } - std::unique_ptr arg_literal = - LiteralUtil::MakeTupleOwned(std::move(arg_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements)); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); + {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_); } } @@ -476,9 +474,9 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. - std::vector> outer_tuple_elements; + std::vector outer_tuple_elements; for (int i = 0; i < kFanout; ++i) { - std::vector> inner_tuple_elements; + std::vector inner_tuple_elements; for (int j = 0; j < kFanout; ++j) { inner_tuple_elements.push_back(LiteralUtil::CreateR0(i + j)); } @@ -487,16 +485,16 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { } auto arg_literal = LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { - LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), - error_spec_); + LiteralTestUtil::ExpectR0Near(i + j + i * kFanout + j, + LiteralSlice(result_literal, {i, j}), + error_spec_); } } } @@ -525,23 +523,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. - std::unique_ptr arg_literal = LiteralUtil::CreateR0(123.0); + Literal arg_literal = LiteralUtil::CreateR0(123.0); for (int i = 0; i < kTupleDepth; ++i) { - std::vector> arg_vector; + std::vector arg_vector; arg_vector.push_back(std::move(arg_literal)); arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector)); } - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); ShapeIndex index; for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal(165.0, - LiteralSlice(*result_literal, index)); + LiteralSlice(result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -552,7 +550,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -568,7 +566,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -585,7 +583,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( @@ -622,7 +620,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { DefaultExecutableRunOptions().set_device_ordinal(d)); EXPECT_EQ(d, result.device_ordinal()); LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + ShapedBufferToLiteral(result)); } } } @@ -666,8 +664,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // As a check to verify that the computation ran of the device associated // with the stream. This is a weak check, but stronger verification is hard. EXPECT_EQ(d, result.device_ordinal()); - LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + LiteralTestUtil::ExpectR0Equal(42.0f, ShapedBufferToLiteral(result)); } } @@ -745,11 +742,11 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); + Literal tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, - LiteralSlice(*tuple_literal, {0})); + LiteralSlice(tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, - LiteralSlice(*tuple_literal, {1})); + LiteralSlice(tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { @@ -768,7 +765,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { executable_status.ConsumeValueOrDie(); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -778,7 +775,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { ->BlockHostUntilDone()); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { @@ -792,33 +789,33 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; // Array shapes. - test_to_device_and_back(*LiteralUtil::CreateR0(42.0)); - test_to_device_and_back(*LiteralUtil::CreateR0(true)); - test_to_device_and_back(*LiteralUtil::CreateR1({1.0, 42.0, 744.4})); + test_to_device_and_back(LiteralUtil::CreateR0(42.0)); + test_to_device_and_back(LiteralUtil::CreateR0(true)); + test_to_device_and_back(LiteralUtil::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). - test_to_device_and_back(*LiteralUtil::MakeTuple({})); + test_to_device_and_back(LiteralUtil::MakeTuple({})); // Non-nested tuples. - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12223.0).get()})); - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(12223.0)})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)})); // Nested tuple. - test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()}) - .get(), - LiteralUtil::CreateR0(false).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)}), + LiteralUtil::CreateR0(false)})); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { @@ -832,17 +829,17 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); test_to_device_and_back( - *LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); - test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456789000LL).get()})); + LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456789000LL)})); } XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { @@ -852,7 +849,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); Add(in, constant); - std::unique_ptr result; + Literal result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -861,13 +858,13 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); // Join the thread. thread.reset(); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { @@ -884,14 +881,14 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + TF_ASSERT_OK_AND_ASSIGN(Literal result, local_client_->TransferFromOutfeedLocal( shape, local_client_->default_device_ordinal())); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } // Benchmark that measures the overhead of the LocalClient API when running a @@ -922,8 +919,8 @@ void BM_LocalClientOverhead(int num_iters) { auto literal = LiteralUtil::CreateR2({{0, 0, 0}, {0, 0, 0}}); auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, - buffer)); + ASSERT_IS_OK( + transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer)); const int kWarmups = 2; diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index eaddf756dbc913dd9668cd22228fbd18c2c33309..f90ef22d2d549f451f8af231aea834e9f097b12a 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -18,11 +18,11 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -136,7 +136,7 @@ ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer( .ConsumeValueOrDie(); } -std::unique_ptr LocalClientTestBase::ShapedBufferToLiteral( +Literal LocalClientTestBase::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { return local_client_->ShapedBufferToLiteral(shaped_buffer) .ConsumeValueOrDie(); @@ -156,7 +156,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -164,7 +164,7 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { return ExecuteLocally(computation, arguments, build_options, run_options) @@ -173,14 +173,14 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()); } StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { std::vector argument_layouts(arguments.size()); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index b4477e9a6b23363ee3a1380f9f98f4b8226f6920..4027c7b124f8ac6e4b94600871ac32376a3e6467 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -86,26 +86,25 @@ class LocalClientTestBase : public ::testing::Test { // Construct and return a literal containing the array represented by // shaped_buffer. - std::unique_ptr ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Execute the given computation on the local client. With and without // options. StatusOr ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); StatusOr ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); ScopedShapedBuffer ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); ScopedShapedBuffer ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 0732e195d44d738b264361e43d38259c26a4116e..4d327a6fe9c45174a0666fd573a081e0cfe450d2 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); + Literal param0_literal = LiteralUtil::CreateR0(42.0); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, @@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, @@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, @@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -224,12 +224,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) { TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( @@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( @@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); @@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). XLA_TEST_F(MapTest, AddWithMixedLayouts) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2WithLayout( + Literal param0_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = LiteralUtil::CreateR2WithLayout( + Literal param1_literal = LiteralUtil::CreateR2WithLayout( {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1}); @@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1, 2}); @@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - std::unique_ptr param2_literal = + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); + Literal param2_literal = LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = - client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); + client_->TransferToServer(param2_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); - auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); + auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2"); Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( @@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) { Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, error_add, {0}); StatusOr computation_status = builder.Build(); @@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, power, {}); ComputeAndCompareR0(&builder, 32.0f, @@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, sub_opposite, {}); ComputeAndCompareR0( @@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) { Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); + Literal param0_literal = LiteralUtil::CreateR0(10.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param0}, square, {}); ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index da8c42d465340f2af3d6acd2c3676b69512f193f..3f278115e078877de1683574370df7790c2801fd 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -17,12 +17,14 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -32,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -62,11 +63,11 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { }); Exp(data); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 {0.36788f, 1.64872f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { @@ -91,10 +92,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { }); Map(&builder, {data}, add_half, {0, 1}); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 {-0.5f, 1.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { @@ -110,10 +111,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { }); Max(lhs, rhs); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 {3.0f, -4.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6)); } struct TestLinspaceMaxParam { @@ -133,10 +134,9 @@ class TestLinspaceMaxParametric float from = -128.0, to = 256.0; std::unique_ptr> alhs = MakeLinspaceArray2D(from, to, rows, cols); - auto arhs = MakeUnique>(rows, cols, static_cast(1.0f)); + auto arhs = absl::make_unique>(rows, cols, static_cast(1.0f)); - XlaBuilder builder( - tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); + XlaBuilder builder(absl::StrFormat("max_%dx%d_linspace", rows, cols)); auto lhs = ConstantR2FromArray2D(&builder, *alhs); auto rhs = ConstantR2FromArray2D(&builder, *arhs); Max(lhs, rhs); @@ -158,7 +158,7 @@ class TestLinspaceMaxParametric string PrintTestLinspaceMaxParam( const ::testing::TestParamInfo& test_param) { const TestLinspaceMaxParam& param = test_param.param; - return tensorflow::strings::StrCat(param.rows, "r", param.cols, "c"); + return absl::StrCat(param.rows, "r", param.cols, "c"); } #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 @@ -200,14 +200,12 @@ class MatOpsDotAddTest TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder builder(TestName()); auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index eb06b115daa96bccd73de30bb7fa30733a6fd947..56aaeb0e6878737e6c689e8065d8f1e1871b3472 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -19,10 +19,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -36,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -46,18 +47,26 @@ limitations under the License. namespace xla { namespace { -using ::tensorflow::gtl::ArraySlice; - class MultiOutputFusionTest : public HloTestBase { protected: MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; } + // Layout assignment assumes that there are no fusions in the input graph. + // Since the purpose of this test is to send pre-fused graphs to XLA, we have + // to do layout assignment ourselves. + DebugOptions GetDebugOptionsForTest() override { + auto opts = HloTestBase::GetDebugOptionsForTest(); + opts.add_xla_disable_hlo_passes("layout-assignment"); + return opts; + } + void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {}); - const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); + const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); + const Shape elem_shape2 = + ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0}); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(8.0f))); @@ -80,13 +89,13 @@ class MultiOutputFusionTest : public HloTestBase { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2))); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { - auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( - ArraySlice({sub, add2}, 0, 2))); + auto tuple = + computation->AddInstruction(HloInstruction::CreateTuple({sub, add2})); auto gte0 = computation->AddInstruction( HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0)); auto gte1 = computation->AddInstruction( @@ -100,23 +109,25 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal arg1(ShapeUtil::MakeShape(F32, {size, size})); + Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); arg1.PopulateWithValue(2.5f); - Literal expect(ShapeUtil::MakeShape(F32, {size, size})); + Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); + Literal literal_r0 = LiteralUtil::CreateR0(-9.0f); auto actual = - ExecuteAndTransfer(std::move(hlo_module), - {LiteralUtil::CreateR0(-9.0f).get(), &arg1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1}); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } void RunTest1D(bool manual_fusion, int size) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size}); - const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size}); + const Shape elem_shape_F32 = + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); + const Shape elem_shape_U8 = + ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, elem_shape_F32, "0")); auto param1 = builder.AddInstruction( @@ -136,17 +147,18 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {size, 1}), add)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add)); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape, + dot_dnums, DefaultPrecisionConfig(2))); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { - auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( - ArraySlice({sub_U8, add}, 0, 2))); + auto tuple = computation->AddInstruction( + HloInstruction::CreateTuple({sub_U8, add})); auto gte0 = computation->AddInstruction( HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0)); @@ -161,15 +173,14 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input0(ShapeUtil::MakeShape(F32, {size})); + Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size})); input0.PopulateWithValue(2.5f); - Literal input1(ShapeUtil::MakeShape(F64, {size})); + Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size})); input1.PopulateWithValue(1.); - Literal expect = - std::move(*LiteralUtil::CreateR1({size * 1.5f * 3.5f})); + Literal expect = LiteralUtil::CreateR1({size * 1.5f * 3.5f}); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } }; @@ -206,10 +217,9 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { LiteralUtil::CreateR0(1.0)), LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), LiteralUtil::CreateR0(4))); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -235,9 +245,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, result); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { @@ -268,9 +277,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, result); } const char* const kScalarOps = R"( @@ -291,7 +299,7 @@ const char* const kScalarOps = R"( XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -312,18 +320,17 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -344,18 +351,17 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -377,18 +383,17 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), - LiteralUtil::CreateR1({36, 64}), - LiteralUtil::CreateR1({66, 138})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR1({36, 64}), + LiteralUtil::CreateR1({66, 138})), + result)); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -410,19 +415,18 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -444,20 +448,19 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -480,21 +483,20 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR1({14, 22}), LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR3( {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) init1 = f32[] parameter(1) @@ -518,18 +520,18 @@ XLA_TEST_F(MultiOutputFusionTest, LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = LiteralUtil::CreateR0(5); auto init2 = LiteralUtil::CreateR0(6); - std::unique_ptr result = ExecuteNoHloPasses( - std::move(module), {param.get(), init1.get(), init2.get()}); + Literal result = + ExecuteNoHloPasses(std::move(module), {¶m, &init1, &init2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{167, 172}, {176, 180}}), LiteralUtil::CreateR2({{6, 6}, {6, 8}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) { p0 = f16[2,2,2]{2,1,0} parameter(0) convert = f32[2,2,2]{2,1,0} convert(p0) @@ -553,10 +555,9 @@ XLA_TEST_F(MultiOutputFusionTest, auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}}), LiteralUtil::CreateR3( @@ -564,7 +565,7 @@ XLA_TEST_F(MultiOutputFusionTest, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}})), - *result)); + result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c530591c6e5fe75658dd507d794f8b6a64442871 --- /dev/null +++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { +StatusOr BuildComputation() { + XlaBuilder b("computation"); + Shape scalar_s32 = ShapeUtil::MakeShape(S32, {}); + XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32); + return b.Build( + OutfeedWithToken(GetTupleElement(infeed, 0) + + ConstantLiteral(&b, LiteralUtil::CreateR0(1)), + GetTupleElement(infeed, 1), scalar_s32, "")); +} + +void CompileAndExecute( + LocalExecutable* executable, int device_ordinal, LocalClient* client, + absl::Mutex* results_mutex, + std::vector>>* results) { + xla::ExecutableRunOptions execute_options; + execute_options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + execute_options.set_device_ordinal(device_ordinal); + execute_options.set_allocator( + xla::ClientLibrary::GetXlaService(client->platform()) + ->backend() + .memory_allocator()); + StatusOr result = executable->Run({}, execute_options); + { + absl::MutexLock lock(results_mutex); + results->emplace_back(device_ordinal, std::move(result)); + } +} + +void TestWithDeviceCount(const int device_count) { + // Run `device_count` copies of the XLA program built by BuildComputation. + TF_ASSERT_OK_AND_ASSIGN( + se::Platform* const platform, + perftools::gputools::MultiPlatformManager::PlatformWithName("Host")); + xla::LocalClientOptions client_options; + client_options.set_platform(platform); + TF_ASSERT_OK_AND_ASSIGN( + LocalClient* const client, + xla::ClientLibrary::GetOrCreateLocalClient(client_options)); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{})); + std::vector threads; + absl::Mutex results_mutex; + std::vector>> results; + tensorflow::Env* env = tensorflow::Env::Default(); + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + tensorflow::Thread* t = env->StartThread( + tensorflow::ThreadOptions{}, absl::StrCat("thread-", device_ordinal), + [&executable, device_ordinal, client, &results_mutex, &results] { + CompileAndExecute(executable.get(), device_ordinal, client, + &results_mutex, &results); + }); + threads.push_back(t); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK(client->TransferToInfeedLocal( + LiteralUtil::CreateR0(device_ordinal * 100), device_ordinal)); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK_AND_ASSIGN(Literal outfeed, + client->TransferFromOutfeedLocal( + ShapeUtil::MakeShape(S32, {}), device_ordinal)); + EXPECT_EQ(outfeed, LiteralUtil::CreateR0(device_ordinal * 100 + 1)); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + delete threads[device_ordinal]; + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK(results[device_ordinal].second.status()); + } +} + +// NB! This test requires --xla_force_host_platform_device_count=4 + +TEST(MultipleDeviceOnHostTest, OneDevice) { TestWithDeviceCount(1); } + +TEST(MultipleDeviceOnHostTest, TwoDevices) { TestWithDeviceCount(2); } + +TEST(MultipleDeviceOnHostTest, ThreeDevices) { TestWithDeviceCount(3); } + +TEST(MultipleDeviceOnHostTest, FourDevices) { TestWithDeviceCount(4); } +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc index 0a0426adcbc1b5b89be0841fa2c4204e2b65abf4..f2460822a61fef11e5c35c731fa6eca5df72b60b 100644 --- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc +++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc @@ -70,7 +70,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { GetTupleElement(result_tuple, 0); TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -81,41 +81,41 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { VLOG(1) << "Transferring trip count to computation"; // Transfer number of iterations to Infeed. TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(1))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(1))); // Pick up value from outfeed { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 1); + EXPECT_EQ(r.Get({}), 1); } VLOG(1) << "Writing data to infeed"; // Transfer some stuff to Infeed for use inside of loop. TF_ASSERT_OK(local_client_->TransferToInfeed( - *LiteralUtil::CreateR1({10, 20}))); + LiteralUtil::CreateR1({10, 20}))); // Pick up value from outfeed { VLOG(1) << "Reading from body outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&xfeed_shape)); - EXPECT_EQ(r->Get({0}), 11); - EXPECT_EQ(r->Get({1}), 21); + EXPECT_EQ(r.Get({0}), 11); + EXPECT_EQ(r.Get({1}), 21); } { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 0); + EXPECT_EQ(r.Get({}), 0); } // Joins the thread thread.reset(); - EXPECT_EQ(comp_result->Get({}), 0); + EXPECT_EQ(comp_result.Get({}), 0); } XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { @@ -145,7 +145,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -154,12 +154,12 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { })); TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(true))); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&result_shape)); - EXPECT_EQ(r->Get({}), true); + EXPECT_EQ(r.Get({}), true); // Join the thread thread.reset(); diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index ca21b0b2ba590a6daadf2c8d3d9ad213514b0f0f..6e98167739c234fae335bcc9e024423e7fc87197 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -16,12 +16,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - Pad(AddParam(*LiteralUtil::CreateR1({1, 2, 3}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({1, 2, 3}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } @@ -132,7 +132,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*LiteralUtil::CreateR0(1.5), &b), + AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); @@ -140,7 +140,7 @@ XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { XlaBuilder b(TestName()); - auto input = MakeUnique>(1, 1, 3, 2); + auto input = absl::make_unique>(1, 1, 3, 2); Array2D input_xy({ {1.0f, 2.0f}, // row 0 {3.0f, 4.0f}, // row 1 @@ -148,10 +148,10 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), + Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); - auto expected = MakeUnique>(2, 3, 3, 2); + auto expected = absl::make_unique>(2, 3, 3, 2); expected->Fill(1.5); (*expected)(1, 0, 0, 0) = 1.0f; (*expected)(1, 0, 0, 1) = 2.0f; @@ -168,10 +168,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); Pad(AddParam(input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); - auto expected = MakeUnique>(8, 5, 1, 1); + auto expected = absl::make_unique>(8, 5, 1, 1); expected->Fill(pad_value); (*expected)(1, 0, 0, 0) = 1.0f; (*expected)(1, 2, 0, 0) = 2.0f; @@ -208,10 +208,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -254,10 +254,10 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -269,7 +269,7 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { XLA_TEST_F(PadTest, Pad4DU8Array) { XlaBuilder b(TestName()); - auto input = MakeUnique>(1, 1, 3, 2); + auto input = absl::make_unique>(1, 1, 3, 2); Array2D input_xy({ {1, 2}, // row 0 {3, 4}, // row 1 @@ -280,7 +280,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { Pad(AddParam(*input, &b), ConstantR0(&b, 35), r4_padding_on_dim0_dim1_); - auto expected = MakeUnique>(2, 3, 3, 2); + auto expected = absl::make_unique>(2, 3, 3, 2); expected->Fill(35); (*expected)(1, 0, 0, 0) = 1; (*expected)(1, 0, 0, 1) = 2; @@ -301,13 +301,13 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { Pad(input, ConstantR0(&b, false), r4_padding_on_dim0_dim1_); // For the same reason, use Select to convert boolean values to int32. - auto zeros = MakeUnique>(2, 3, 3, 2); - auto ones = MakeUnique>(2, 3, 3, 2); + auto zeros = absl::make_unique>(2, 3, 3, 2); + auto ones = absl::make_unique>(2, 3, 3, 2); zeros->Fill(0); ones->Fill(1); Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); - auto expected = MakeUnique>(2, 3, 3, 2); + auto expected = absl::make_unique>(2, 3, 3, 2); expected->Fill(0); (*expected)(1, 0, 0, 0) = 1; (*expected)(1, 0, 0, 1) = 1; @@ -321,7 +321,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { XLA_TEST_P(PadTestFloat, Large2DPad) { XlaBuilder b(TestName()); - auto ones = MakeUnique>(4, 4); + auto ones = absl::make_unique>(4, 4); ones->Fill(1.0f); auto input = AddParam(*ones, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -331,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - Pad(input, AddParam(*LiteralUtil::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -342,7 +342,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { constexpr int64 in_rows = 35; constexpr int64 in_cols = 35; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(0.0f); auto input = AddParam(*operand, &b); @@ -353,8 +353,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - Pad(input, AddParam(*LiteralUtil::CreateR0(3.14f), &b), - padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(3.14f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -368,7 +367,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { constexpr int64 low_padding = 0; int64 high_padding[2] = {5, 7}; constexpr int64 interior_padding = 0; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(1.0f); auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -379,7 +378,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -395,7 +394,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {-3, 4}; constexpr int64 interior_padding = 0; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(1.0f); auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -407,7 +406,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -423,7 +422,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { int64 low_padding[2] = {4, -1}; int64 high_padding[2] = {-2, -4}; int64 interior_padding[2] = {1, 2}; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(1.0f); auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -435,7 +434,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -446,19 +445,18 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { // Regression test for b/31827337. XLA_TEST_P(PadTestFloat, ReducePad) { XlaBuilder b(TestName()); - auto ones = MakeUnique>(2, 2, 2, 2); + auto ones = absl::make_unique>(2, 2, 2, 2); ones->Fill(1.0); auto input = AddParam(*ones, &b); XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - Reduce(input, AddParam(*LiteralUtil::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(LiteralUtil::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - Pad(reduce, AddParam(*LiteralUtil::CreateR0(0.0f), &b), - padding_config); + Pad(reduce, AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index f6c762e7a4bee91a26c4c2e033c3717fef6d91d0..dcb4c11c3ccab5992e1ea4fadf02fda8ff77e7ea 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -42,10 +42,9 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR0(3.14159f); + Literal param0_literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); @@ -55,9 +54,9 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0"); @@ -67,10 +66,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); @@ -81,9 +79,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XlaBuilder builder(TestName()); string str("hello world"); - std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); + Literal param0_literal = LiteralUtil::CreateR1U8(str); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), @@ -94,10 +92,10 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); @@ -107,10 +105,10 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); @@ -123,15 +121,15 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); // Use both parameters // @@ -154,9 +152,9 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); + Literal literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2"); @@ -168,15 +166,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - Parameter(&builder, 0, literal0->shape(), "param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Parameter(&builder, 0, literal0.shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + Parameter(&builder, 1, literal1.shape(), "param1"); ComputeAndCompareR1(&builder, {10, 20}, {param0_data.get(), param1_data.get()}, @@ -188,18 +186,17 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = - LiteralUtil::CreateR1({10, 20, 30}); + Literal literal1 = LiteralUtil::CreateR1({10, 20, 30}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&builder, 2, literal1->shape(), "param2"); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&builder, 2, literal1.shape(), "param2"); // This add is unused. Add(param1, param2); @@ -233,10 +230,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); + Literal literal = LiteralUtil::CreateR1(sum_value); param_data_owner.push_back( - client_->TransferToServer(*literal).ConsumeValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + client_->TransferToServer(literal).ConsumeValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -268,10 +265,10 @@ XLA_TEST_F(ParamsTest, constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR0(i); + Literal literal = LiteralUtil::CreateR0(i); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -300,10 +297,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); sum_handle = Add(sum_handle, param); } @@ -321,13 +318,14 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({target + i, target + i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } // Test large number of parameters flowing into a while-loop. @@ -356,23 +354,23 @@ XLA_TEST_F(ParamsTest, std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); - parameter_shapes.push_back(literal->shape()); + parameter_shapes.push_back(literal.shape()); } // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. - std::unique_ptr bool_literal = LiteralUtil::CreateR0(false); + Literal bool_literal = LiteralUtil::CreateR0(false); param_data_owner.push_back( - std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + std::move(client_->TransferToServer(bool_literal)).ValueOrDie()); XlaOp bool_param = - Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param"); + Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param"); params.push_back(bool_param); - parameter_shapes.push_back(bool_literal->shape()); + parameter_shapes.push_back(bool_literal.shape()); auto init = Tuple(&builder, params); @@ -420,13 +418,14 @@ XLA_TEST_F(ParamsTest, param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({i, i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } #endif @@ -443,9 +442,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR1({4, 5, 6}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR1({4, 5, 6}), })) .ConsumeValueOrDie(); @@ -457,34 +456,34 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + Literal literal = LiteralUtil::CreateR2({ {1, 3}, {2, 4}, }); - const Shape original = literal->shape(); + const Shape original = literal.shape(); { // Reverse the layout present in original, and make that the layout of the // literal. @@ -492,9 +491,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { original.layout().minor_to_major().begin(), original.layout().minor_to_major().end()); std::reverse(original_layout.begin(), original_layout.end()); - *literal->mutable_shape_do_not_use()->mutable_layout() = + *literal.mutable_shape_do_not_use()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); - ASSERT_EQ(2, literal->Get({0, 1})); + ASSERT_EQ(2, literal.Get({0, 1})); } // Use the original shape in building the computation. XlaBuilder builder(TestName()); @@ -503,7 +502,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { Slice(input, {0, 1}, {1, 2}, {1, 1}); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); // Check that we got the off-diagonal value that we expected. Array2D expected(1, 1); expected(0, 0) = 2; diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 2fc7f816b56db6f57ca835d1847476b6d622ce5e..58539e6b061b0cec1cc660b52e78894e5deeea56 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -31,7 +31,7 @@ class PredTest : public ClientLibraryTestBase { protected: void TestCompare(bool lhs, bool rhs, bool expected, std::function)> + absl::Span)> op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 326e13b3867f2f804e882e00e35850d0189ad8d7..8f2c26f0eea9c7a3b33cd77e5977924c1659535a 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -37,9 +37,7 @@ namespace { class PrngTest : public ClientLibraryTestBase { protected: template - std::unique_ptr UniformTest(T a, T b, - tensorflow::gtl::ArraySlice dims, - int64 seed = 42); + Literal UniformTest(T a, T b, absl::Span dims, int64 seed = 42); // Computes the χ² statistic of a sample of the discrete uniform distribution // of the given range size. `expected_count` is the number of times each @@ -50,8 +48,8 @@ class PrngTest : public ClientLibraryTestBase { }; template -std::unique_ptr PrngTest::UniformTest( - T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { +Literal PrngTest::UniformTest(T a, T b, absl::Span dims, + int64 seed) { XlaBuilder builder(TestName()); RngUniform( ConstantR0(&builder, a), ConstantR0(&builder, b), @@ -60,8 +58,8 @@ std::unique_ptr PrngTest::UniformTest( SetSeed(seed); auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); - EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { + EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions())); + actual.EachCell([=](absl::Span, T value) { EXPECT_LE(a, value); EXPECT_LT(value, b); }); @@ -116,11 +114,10 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { constexpr int64 count = 100; for (int64 seed = 0; seed < count; ++seed) { auto result = UniformTest(low, high, {}, /*seed=*/seed); - result->Literal::EachCell( - [&](tensorflow::gtl::ArraySlice, bfloat16 value) { - int64 index = static_cast((value - low) / interval); - counts[index]++; - }); + result.EachCell([&](absl::Span, bfloat16 value) { + int64 index = static_cast((value - low) / interval); + counts[index]++; + }); } // Each bucket should have similar amount of counts. That is, not more than // 10% of total counts. This mostly tests that we don't fall into a 1:2:2 @@ -149,8 +146,8 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); std::vector counts(range_size, 0); - actual->EachCell([&counts](tensorflow::gtl::ArraySlice, - int32 value) { ++counts[value]; }); + actual.EachCell( + [&counts](absl::Span, int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { sum += Square(static_cast(counts[i] - expected_count)); @@ -192,12 +189,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) { }; XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, - client_->TransferToServer(*param0_literal)); + client_->TransferToServer(param0_literal)); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto fn = build_sum_rng(builder); Map(&builder, {param0}, fn, {0}); @@ -210,12 +207,11 @@ XLA_TEST_F(PrngTest, MapUsingRng) { computation, /*arguments=*/{param0_data.get()}, &execution_options)); - EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()), - ShapeUtil::ElementsIn(param0_literal->shape())); - for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) { - EXPECT_GE(actual->data()[i], param0_literal->data()[i]); - EXPECT_LT(actual->data()[i], - param0_literal->data()[i] + 1.0f); + EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()), + ShapeUtil::ElementsIn(param0_literal.shape())); + for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) { + EXPECT_GE(actual.data()[i], param0_literal.data()[i]); + EXPECT_LT(actual.data()[i], param0_literal.data()[i] + 1.0f); } } @@ -238,15 +234,15 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { ExecutionOptions execution_options2 = execution_options_; execution_options2.set_seed(65); - std::unique_ptr result1; + Literal result1; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options1)); } - std::unique_ptr result2; - std::unique_ptr result3; + Literal result2; + Literal result3; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -257,9 +253,9 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options1)); } - std::unique_ptr result4; - std::unique_ptr result5; - std::unique_ptr result6; + Literal result4; + Literal result5; + Literal result6; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -273,11 +269,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index a080dd1732bde21712cf47b4b57538cf4040f30e..c9096fb29b2019796c42b69de80c63b5fc7c5c3a 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,11 +15,11 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -29,16 +29,13 @@ limitations under the License. namespace xla { namespace { -namespace str_util = tensorflow::str_util; -namespace strings = tensorflow::strings; - struct ReduceLayout { std::array input_minor_to_major; std::array output_minor_to_major; string ToString() const { - return strings::StrCat(str_util::Join(input_minor_to_major, "x"), "_", - str_util::Join(output_minor_to_major, "x")); + return absl::StrCat(absl::StrJoin(input_minor_to_major, "x"), "_", + absl::StrJoin(output_minor_to_major, "x")); } }; @@ -95,7 +92,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { *reduce_input_shape->mutable_layout() = LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); - std::unique_ptr reduce_input = LiteralUtil::CreateR4( + Literal reduce_input = LiteralUtil::CreateR4( {{ /*i0=0*/ {/*i1=0*/ {-0.246092796, -0.179497838, -0.161181688}, diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 531648fe3eb8e3941c5e3c012847ee68c616590f..26e2bfde5cdc19657640f24f31bc008d09ad7106 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -57,8 +58,8 @@ static const int mantissa_sizes[] = {23, 10, 23, 10}; string TestDataToString(const ::testing::TestParamInfo data) { int i = data.param; - return tensorflow::strings::StrCat(exponent_sizes[i], "_exponent_bits_", - mantissa_sizes[i], "_mantissa_bits"); + return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i], + "_mantissa_bits"); } // The FPVAL macro allows us to write out the binary representation of the @@ -230,11 +231,10 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR1({input_values}); + Literal a_literal = LiteralUtil::CreateR1({input_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); ReducePrecision(a, exponent_bits, mantissa_bits); @@ -254,10 +254,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // Abs doesn't affect resolution. auto abs = Abs(a); @@ -283,10 +283,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -309,10 +309,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -333,10 +333,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -358,10 +358,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 2065271a7f686c52c88df80b0efe8f2e1542d198..83997cdac21c437d460dabdbdfdb31100b1359af 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -32,6 +32,9 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -51,7 +54,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -79,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase { }, 4); // clang-format on CHECK(ShapeUtil::Equal( - literal_3d_->shape(), + literal_3d_.shape(), ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3}))) - << literal_3d_->shape().ShortDebugString(); + << literal_3d_.shape().ShortDebugString(); } // Runs an R1 => R0 reduction test with the given number of elements. @@ -100,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase { input_data[i] *= -1; } } - std::unique_ptr input_literal = - LiteralUtil::CreateR1(AsSlice(input_data)); + Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (float item : input_data) { @@ -113,8 +114,7 @@ class ReduceTest : public ClientLibraryTestBase { ErrorSpec(0.001)); } - void RunR1ToR0PredTest(bool and_reduce, - tensorflow::gtl::ArraySlice input_data) { + void RunR1ToR0PredTest(bool and_reduce, absl::Span input_data) { const int element_count = input_data.size(); XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count}); @@ -133,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase { Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); + Literal input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); bool expected = and_reduce; for (bool item : input_data) { @@ -174,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(0, 1); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::array expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -208,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (int64 rowno = 0; rowno < rows; ++rowno) { @@ -236,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -259,8 +256,8 @@ class ReduceTest : public ClientLibraryTestBase { void ComputeAndCompareGeneric( typename std::enable_if::value, XlaBuilder>::type* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span expected, + absl::Span arguments) { ComputeAndCompareR1(builder, expected, arguments, ErrorSpec(0.01, 1e-4)); } @@ -269,8 +266,8 @@ class ReduceTest : public ClientLibraryTestBase { void ComputeAndCompareGeneric( typename std::enable_if::value, XlaBuilder>::type* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span expected, + absl::Span arguments) { ComputeAndCompareR1(builder, expected, arguments); } @@ -294,15 +291,14 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillUnique(initial_value); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); // NativeT can be bool, and std::vector does not convert to - // ArraySlice. + // Span. std::unique_ptr expected(new NativeT[cols]); for (int64 colno = 0; colno < cols; ++colno) { NativeT column_result = initial_value; @@ -314,7 +310,7 @@ class ReduceTest : public ClientLibraryTestBase { } ComputeAndCompareGeneric( - &builder, tensorflow::gtl::ArraySlice(expected.get(), cols), + &builder, absl::Span(expected.get(), cols), {input_global_data.get()}); } @@ -351,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase { reference_reduction_function_for_uints, unsigned_int_identity); } - std::unique_ptr literal_2d_; - std::unique_ptr literal_3d_; + Literal literal_2d_; + Literal literal_3d_; uint32 seed_ = 0xdeadbeef; }; @@ -449,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -481,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -510,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2}); Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - MakeFakeLiteral(input_shape)); + TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape)); - ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4)); + ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4)); } XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { @@ -530,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3D(input_data); + Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 major = 0; major < 2; ++major) { @@ -556,12 +548,11 @@ struct BoundsLayout { }; void PrintTo(const BoundsLayout& spec, std::ostream* os) { - *os << tensorflow::strings::Printf( - "R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(), - spec.bounds.size() - spec.reduce_dims.size(), - tensorflow::str_util::Join(spec.bounds, "x").c_str(), - tensorflow::str_util::Join(spec.layout, "").c_str(), - tensorflow::str_util::Join(spec.reduce_dims, "").c_str()); + *os << absl::StrFormat("R%uToR%u%s_%s_Reduce%s", spec.bounds.size(), + spec.bounds.size() - spec.reduce_dims.size(), + absl::StrJoin(spec.bounds, "x"), + absl::StrJoin(spec.layout, ""), + absl::StrJoin(spec.reduce_dims, "")); } // Add-reduces a broadcasted scalar matrix among dimension 1 and 0. @@ -595,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { Array2D input(300, 250); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; input.Each( @@ -610,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { Array2D input(150, 130); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MAX), min, {0, 1}); auto input_min = FLT_MAX; @@ -627,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { auto initial_value = ConstantR0(&builder, std::numeric_limits::max()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1}); ComputeAndCompareR0(&builder, 1, {}); } @@ -639,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { auto initial_value = ConstantR0(&builder, std::numeric_limits::min()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1}); ComputeAndCompareR0(&builder, 2, {}); } // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); @@ -657,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -667,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XlaBuilder builder("reduce_among_y"); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -677,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1, 2}); @@ -687,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -697,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1, 2}); @@ -707,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -722,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); @@ -739,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {2}); @@ -824,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); + input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); Reduce(input_activations, ConstantR0(&builder, 0.0f), add, GetParam().reduce_dims); @@ -866,21 +857,17 @@ INSTANTIATE_TEST_CASE_P( BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}}, BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}})); -// TODO(b/64093391) Disabled on GPU due to an assertion failure when running -// IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on -// 2017-07-26. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { +XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) { XlaBuilder builder(TestName()); XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder); auto a = ConstantR0(&builder, 2.0f); auto a2 = Abs(a); - std::unique_ptr b_literal = - LiteralUtil::CreateR1({1.0f, 4.0f}); + Literal b_literal = LiteralUtil::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b = Parameter(&builder, 0, b_literal->shape(), "b"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b = Parameter(&builder, 0, b_literal.shape(), "b"); Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); @@ -907,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest { std::vector input_arr(num_elems, std::numeric_limits::lowest()); auto input_literal = LiteralUtil::CreateR1(input_arr); auto input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, - max_fn, {0}); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn, + {0}); ComputeAndCompareR0(&builder, initializer, {input_data.get()}); } @@ -955,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr input_literal = - LiteralUtil::CreateR1(operand); + Literal input_literal = LiteralUtil::CreateR1(operand); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr input_literal2 = LiteralUtil::CreateR0(init); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Literal input_literal2 = LiteralUtil::CreateR0(init); std::unique_ptr input_global_data2 = - client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); + client_->TransferToServer(input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0( &builder, expected, {input_global_data.get(), input_global_data2.get()}, ErrorSpec(0.0001)); diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 1bd6fdab31d6c3516339bdb98459ffe3bbdef1d1..c25ccafaf83cf1b29095a77eefa357d9af08dc60 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -18,6 +18,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -35,7 +39,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -54,7 +57,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase { public: ErrorSpec DefaultErrorSpec() const { if (use_bfloat16()) { - return ErrorSpec(1e-1, 5e-2); + return ErrorSpec(2e-1, 6e-2); } else { return ErrorSpec(1e-3, 1e-3); } @@ -67,10 +70,10 @@ class ReduceWindowTest : public ::testing::WithParamInterface, ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); } void ReduceWindowAdd(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), + auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), @@ -78,8 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, } void ReduceWindowMax(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_); @@ -89,8 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, } void ReduceWindowMin(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_); @@ -104,9 +107,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); + LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); const auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); ReduceWindow(input, init_value, CreateScalarAddComputation(FloatType(), &builder_), @@ -121,31 +124,31 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { // Regression test for b/68964348. TEST_P(ReduceWindowTest, R0ReduceWindow) { const auto input = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(42.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(42.0), &builder_); const auto init = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(1.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(1.0), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0(43.0), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0(43.0), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride2) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({100, 1}), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({100, 1}), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, Padding::kSame); ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), + LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), {}, ErrorSpec(0.00001)); } @@ -158,7 +161,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -173,7 +176,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -187,7 +190,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -204,7 +207,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -226,8 +229,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { @@ -249,8 +252,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests the super windowing logic w.r.t handling prime number of windows in a @@ -274,8 +277,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { @@ -291,8 +294,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. @@ -310,12 +313,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); Min(Add(lhs, rhs), - CreateConstantFromLiteral(*LiteralUtil::CreateR0(8.0f), b.get())); + CreateConstantFromLiteral(LiteralUtil::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); ReduceWindow( input, - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_), + CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -329,19 +332,18 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected), {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); input_array.FillRandom(2.f, 2.f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -349,7 +351,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -357,9 +359,9 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - auto arg_literal = MakeUnique(shape); - arg_literal->PopulateWithValue(1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + Literal arg_literal(shape); + arg_literal.PopulateWithValue(1.0f); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); @@ -368,39 +370,38 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - auto expected = MakeUnique(result_shape); - expected->PopulateWithValue(27.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + Literal expected(result_shape); + expected.PopulateWithValue(27.0f); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 9.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -410,19 +411,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -432,19 +432,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -454,7 +453,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -475,18 +474,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); ComputeAndCompareLiteral( &builder_, - *LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, DefaultErrorSpec()); } @@ -501,9 +500,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -518,9 +517,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {1}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -537,9 +536,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { @@ -553,9 +551,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -578,21 +575,20 @@ string R4ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // - "__layout_", tensorflow::str_util::Join(param.layout, "_"), // + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // + "__strides_", absl::StrJoin(param.strides, "x"), // + "__pad_low_", absl::StrJoin(param.pad_low, "x"), // + "__pad_high_", absl::StrJoin(param.pad_high, "x"), // + "__layout_", absl::StrJoin(param.layout, "_"), // (param.reducer == kAdd) ? "_add" : "_max"); CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -611,12 +607,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); - input.FillIota(1); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + input.FillRandom(0.1f, 0.1f); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(4); @@ -625,9 +620,16 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, } auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); - auto computation = param.reducer == kAdd + auto reducer = param.reducer; + if (use_bfloat16() && Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // windows. + reducer = kMax; + } + + auto computation = reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); ReduceWindowWithGeneralPadding( @@ -638,8 +640,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window_strides=*/param.strides, /*padding=*/padding); - CHECK(param.reducer == kAdd || param.reducer == kMax); - auto reduce_func = param.reducer == kAdd + CHECK(reducer == kAdd || reducer == kMax); + auto reduce_func = reducer == kAdd ? +[](float a, float b) { return a + b; } : +[](float a, float b) { return std::max(a, b); }; std::unique_ptr> expected = @@ -650,12 +652,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - std::unique_ptr expected_literal = - LiteralUtil::CreateFromArray(*expected); + Literal expected_literal = LiteralUtil::CreateFromArray(*expected); const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( - input_literal->shape().element_type(), - AsInt64Slice(expected_literal->shape().dimensions()), param.layout); - ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()}, + input_literal.shape().element_type(), + AsInt64Slice(expected_literal.shape().dimensions()), param.layout); + ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()}, DefaultErrorSpec(), &expected_shape_with_layout); } }; @@ -807,6 +808,22 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*pad_high=*/{1, 0, 0, 0}, /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3}, + /*window_bounds=*/{1, 64, 64, 1}, + /*strides=*/{1, 64, 64, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 0, 2, 1}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64}, + /*window_bounds=*/{112, 112, 1, 8}, + /*strides=*/{112, 112, 1, 8}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, }; INSTANTIATE_TEST_CASE_P( @@ -928,21 +945,42 @@ struct R3ReduceWindowTestData { {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, }; string R3ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), - "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), - "__strides_", tensorflow::str_util::Join(param.strides, "x"), - "__padding_", param.padding == Padding::kSame ? "same" : "valid", - "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2], - "__reducer_", param.reducer == kAdd ? "add" : "max"); + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_", + absl::StrJoin(param.window_bounds, "x"), "__strides_", + absl::StrJoin(param.strides, "x"), "__padding_", + param.padding == Padding::kSame ? "same" : "valid", "__layout_", + param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", + param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -954,35 +992,41 @@ class R3ReduceWindowTest : public ReduceWindowTestBase, R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } }; -TEST_P(R3ReduceWindowTest, Add) { +TEST_P(R3ReduceWindowTest, DoIt) { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array3D input(param.base_bounds[0], param.base_bounds[1], - param.base_bounds[2], 1.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + param.base_bounds[2]); + input.FillRandom(0.1f, 0.1f); + Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + auto reducer = param.reducer; + if (use_bfloat16()) { + input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); + if (Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // windows. + reducer = kMax; + } + } - XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input"); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); + + auto computation = reducer == kAdd + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + ReduceWindow(/*operand=*/parameter, /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); - auto expected = ReferenceUtil::ReduceWindow3DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); - - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), - {input_arg.get()}, DefaultErrorSpec()); + ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P( @@ -1068,17 +1112,16 @@ string R2ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), - "__layout_", param.layout[0], "_", param.layout[1], // + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // + "__strides_", absl::StrJoin(param.strides, "x"), // + "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_", + absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_", + param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1092,16 +1135,14 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, void DoIt() { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(2); for (int i = 0; i < 2; ++i) { @@ -1111,7 +1152,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1127,7 +1168,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } }; @@ -1261,21 +1302,35 @@ struct R1ReduceWindowTestData { /*pad_low=*/{5}, /*pad_high=*/{0}, /*reducer=*/Reducer::kAdd}, + + // The pattern generated by inclusive scan (cumsum/cumprod). + {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, + /*strides=*/{1}, + /*pad_low=*/{4095}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kMax}, + + // The pattern generated by exclusive scan (cumsum/cumprod). + {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, + /*strides=*/{1}, + /*pad_low=*/{4096}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kMax}, }; string R1ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), - "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), - "__strides_", tensorflow::str_util::Join(param.strides, "x"), - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), - "__reducer_", param.reducer == kAdd ? "add" : "max"); + string str = + absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"), + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), + "__strides_", absl::StrJoin(param.strides, "x"), + "__pad_low_", absl::StrJoin(param.pad_low, "x"), + "__pad_high_", absl::StrJoin(param.pad_high, "x"), + "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1295,11 +1350,11 @@ TEST_P(R1ReduceWindowTest, DoIt) { const float kInitValue = 0.0f; std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); - std::unique_ptr input_literal = - LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); + Literal input_literal = + LiteralUtil::CreateR1(absl::Span(input_vector)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + auto input_arg = + CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(1); padding[0] = {param.pad_low[0], param.pad_high[0]}; @@ -1308,7 +1363,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1320,14 +1375,14 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? +[](float a, float b) { return a + b; } : +[](float a, float b) { return std::max(a, b); }; auto expected = ReferenceUtil::ReduceWindow1DGeneric( - /*operand=*/tensorflow::gtl::ArraySlice(input_vector), + /*operand=*/absl::Span(input_vector), /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1(*expected), {input_arg.get()}, DefaultErrorSpec()); } @@ -1341,7 +1396,7 @@ INSTANTIATE_TEST_CASE_P( // results on the interpreter backend. class ReduceWindowTextTest : public HloTestBase {}; -TEST_F(ReduceWindowTextTest, R2General256x384) { +XLA_TEST_F(ReduceWindowTextTest, R2General256x384) { const string hlo_string = R"( HloModule R2Window mul { @@ -1358,7 +1413,7 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { +XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { const string hlo_string = R"( HloModule R2Window mul { @@ -1375,7 +1430,7 @@ ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window= EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R2General2x5) { +XLA_TEST_F(ReduceWindowTextTest, R2General2x5) { const string hlo_string = R"( HloModule R2Window mul { @@ -1392,7 +1447,7 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { +XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { const string hlo_string = R"( HloModule R2Window mul { @@ -1410,7 +1465,7 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { +XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { const string hlo_string = R"( HloModule R3Window mul { @@ -1428,7 +1483,7 @@ ENTRY R3Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(HloTestBase, ReduceWindowIdentity) { +XLA_TEST_F(HloTestBase, ReduceWindowIdentity) { const string hlo_string = R"( HloModule ReduceWindowIdentity identity.pad_to_reduce_window { @@ -1442,10 +1497,10 @@ ENTRY reduce-window-identity { } )"; - EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); + EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); } -TEST_F(HloTestBase, ReduceWindowS32) { +XLA_TEST_F(HloTestBase, ReduceWindowS32) { const string hlo_string = R"( HloModule reduce-window @@ -1461,7 +1516,26 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { } )"; - EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); + EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); +} + +XLA_TEST_F(HloTestBase, ReduceWindowF16) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] { + %param0 = f16[] parameter(0) + ROOT %param1 = f16[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] { + %parameter.0 = f16[81,8]{1,0} parameter(0) + %parameter.1 = f16[] parameter(1) + ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index d8914513819415368a628eab1f482f9644dd46b1..5cf87e565bf493167f5173588e7afa3b96282488 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -58,13 +58,13 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect 4. - LiteralTestUtil::ExpectR0Equal(4, *literal); + LiteralTestUtil::ExpectR0Equal(4, literal); } XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { @@ -91,12 +91,12 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*LiteralUtil::CreateR0(2)) + client_->TransferToServer(LiteralUtil::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*LiteralUtil::CreateR0(3)) + client_->TransferToServer(LiteralUtil::CreateR0(3)) .ConsumeValueOrDie(); - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{x_data.get(), y_data.get()}, @@ -104,7 +104,7 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { .ConsumeValueOrDie(); // Expect 5. - LiteralTestUtil::ExpectR0Equal(5, *literal); + LiteralTestUtil::ExpectR0Equal(5, literal); } TEST_F(ReplayTest, MapPlusTwoOverR1) { @@ -136,13 +136,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect result. - LiteralTestUtil::ExpectR1Equal({3, 4, 5}, *literal); + LiteralTestUtil::ExpectR1Equal({3, 4, 5}, literal); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 368f5583c9ce3773e57b858ff7606f679346529a..ae24eb5eb4822a2057e34a1aec8b7d64604d8984 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 382d1b1ae741285dcd1f7761edb82a5c333887af..dedc95b5ae8315185a35f786af42aad53bd7ad96 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -57,12 +57,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -70,12 +70,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -83,12 +83,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -99,29 +99,29 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(1.0f); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + Literal param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); auto expected_literal = LiteralUtil::CreateR1({-1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -130,25 +130,25 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) { Array2D input_array(0, 3); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -157,11 +157,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) { Array2D input_array(3, 0); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -170,11 +170,11 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -183,11 +183,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -196,12 +196,12 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -211,13 +211,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); auto expected_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -226,12 +226,12 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -241,14 +241,14 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -258,14 +258,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -274,11 +274,11 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -288,13 +288,13 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -304,13 +304,13 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 3, 0, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 0, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -318,12 +318,12 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{24, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -334,14 +334,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -349,12 +349,12 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -365,14 +365,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); auto expected_literal = LiteralUtil::CreateFromArray(expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -391,14 +391,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -406,7 +406,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{8, 3}); @@ -418,7 +418,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { {35, 36, 37}, {40, 41, 42}, {45, 46, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -426,14 +426,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -441,7 +441,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{8, 3}); @@ -453,7 +453,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { {45, 16, 26}, {36, 46, 17}, {27, 37, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -461,14 +461,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{2, 6, 2}); auto expected_literal = LiteralUtil::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -494,14 +494,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -519,14 +519,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 1) = 7; auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 4}); auto expected_literal = LiteralUtil::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -547,7 +547,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { Reshape(parameter, dimensions, {}); auto expected_literal = LiteralUtil::CreateR0(83.0f); - ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&b, expected_literal, {input.get()}, zero_error_spec_); } } @@ -556,7 +556,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {}, {}); EXPECT_THAT( @@ -568,7 +568,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), @@ -604,7 +604,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); @@ -619,27 +619,26 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, {1, 0}); - std::unique_ptr actual = + Literal actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); - std::unique_ptr expected = - LiteralUtil::CreateR2FromArray2D(expected_array); + Literal expected = LiteralUtil::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralUtil::ConvertF32ToBF16(*expected); + expected = LiteralUtil::ConvertF32ToBF16(expected); } - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); @@ -653,20 +652,20 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { {{204, 205, 206, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); @@ -680,7 +679,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { {{206, 7, 107, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -689,20 +688,17 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -711,20 +707,17 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -734,25 +727,23 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); - input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { + input.Each([&](absl::Span indices, float* cell) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -762,14 +753,13 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); input_array.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + [&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); @@ -778,7 +768,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, {2, 3, 0, 1}); - std::unique_ptr output_literal = + Literal output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, &execution_options) @@ -787,10 +777,10 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. if (use_bfloat16()) { - auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal); - EXPECT_EQ(expected->data(), output_literal->data()); + auto expected = LiteralUtil::ConvertF32ToBF16(input_literal); + EXPECT_EQ(expected.data(), output_literal.data()); } else { - EXPECT_EQ(input_literal->data(), output_literal->data()); + EXPECT_EQ(input_literal.data(), output_literal.data()); } } @@ -801,12 +791,12 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); + ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()}); } XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { @@ -816,7 +806,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaBuilder builder(TestName()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); @@ -833,7 +823,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); + ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()}); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { @@ -842,27 +832,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::vector bounds = {2, 2, 2, 2}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { @@ -871,27 +859,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::vector bounds = {1, 1, 250, 300}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { @@ -900,27 +886,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::vector bounds = {5, 5, 1, 10}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { @@ -930,27 +914,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::vector bounds = {5, 5, 10, 1}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { @@ -959,27 +941,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::vector bounds = {3, 3, 1, 3}; std::vector new_bounds = {bounds[1], bounds[0], bounds[2], bounds[3]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({0, 1, 2, 3})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) - ->Relayout(input_literal->shape().layout()); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal) + .Relayout(input_literal.shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 41e49b4003236d55d85592315652a0ddefd5c485..4e55b0d7ac4453d074500f3a7fda96cb5ab52c56 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -37,16 +39,14 @@ static std::array use_bfloat16_params{false}; #endif struct ReverseSpec { - tensorflow::gtl::ArraySlice input_dims; - tensorflow::gtl::ArraySlice reversal; + absl::Span input_dims; + absl::Span reversal; bool use_bfloat16; string ToTestCaseName() const { - return tensorflow::strings::Printf( - "reverse_%s_in_dims_%s_%s", - tensorflow::str_util::Join(input_dims, "x").c_str(), - tensorflow::str_util::Join(reversal, "x").c_str(), - use_bfloat16 ? "bf16" : "f32"); + return absl::StrFormat( + "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"), + absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32"); } }; @@ -83,26 +83,25 @@ TEST_P(FloatReverseTest, Reverses) { ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); std::iota(input_vector.begin(), input_vector.end(), 0.0); auto r1_literal = LiteralUtil::CreateR1(input_vector); - auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); + auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto a = AddParam(*input_literal, &builder); + auto a = AddParam(input_literal, &builder); Rev(a, spec.reversal); - std::unique_ptr expected = input_literal->CloneToUnique(); + Literal expected = input_literal.Clone(); std::vector output_indices(spec.input_dims.size()); - expected->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float) { - for (int64 i = 0; i < indices.size(); ++i) { - output_indices[i] = indices[i]; - } - float value = input_literal->Get(indices); - for (int64 dim : spec.reversal) { - output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; - } - expected->Set(output_indices, value); - }); - ComputeAndCompareLiteral(&builder, *expected, {}); + expected.EachCell([&](absl::Span indices, float) { + for (int64 i = 0; i < indices.size(); ++i) { + output_indices[i] = indices[i]; + } + float value = input_literal.Get(indices); + for (int64 dim : spec.reversal) { + output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; + } + expected.Set(output_indices, value); + }); + ComputeAndCompareLiteral(&builder, expected, {}); } INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index a620fe19085d98c8b6642b25b159d6c2308bdae2..091a5d2cacce6ac5bf986776e5ec96612d08cc75 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -38,7 +38,7 @@ namespace { class RoundTripPackedLiteralTest : public ClientLibraryTestBase { protected: // Sends the literal to the server and retrieves it back. - std::unique_ptr RoundTripToServer(const Literal& original) { + Literal RoundTripToServer(const Literal& original) { std::unique_ptr data = client_->TransferToServer(original).ConsumeValueOrDie(); return client_->Transfer(*data).ConsumeValueOrDie(); @@ -47,8 +47,7 @@ class RoundTripPackedLiteralTest : public ClientLibraryTestBase { TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { string data(sizeof(float) * 2, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 2); + absl::Span floats(tensorflow::bit_cast(data.data()), 2); floats[0] = 42.0; floats[1] = 24.0; @@ -60,18 +59,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0, actual->Get({0})); - EXPECT_EQ(24.0, actual->Get({1})); + EXPECT_EQ(42.0, actual.Get({0})); + EXPECT_EQ(24.0, actual.Get({1})); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { string data(sizeof(float) * 4, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 4); + absl::Span floats(tensorflow::bit_cast(data.data()), 4); // With x as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=0,x=1 @@ -89,24 +87,22 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({0, 1})); - EXPECT_EQ(64.0f, actual->Get({1, 0})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({0, 1})); + EXPECT_EQ(64.0f, actual.Get({1, 0})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { string data(sizeof(float) * 4, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 4); + absl::Span floats(tensorflow::bit_cast(data.data()), 4); // With y as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=1,x=0 @@ -124,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({1, 0})); - EXPECT_EQ(64.0f, actual->Get({0, 1})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({1, 0})); + EXPECT_EQ(64.0f, actual.Get({0, 1})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index a8193c2eac05ba4f0df339909f3e82a28ac35253..cd5a531603b0cb6e0f48f4dcd49891cbd5428602 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase { void RoundTripTest(const Literal& original) { std::unique_ptr data = client_->TransferToServer(original).ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); + Literal result = client_->Transfer(*data).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralTestUtil::Equal(original, result)); } }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*LiteralUtil::CreateR0(42)); + RoundTripTest(LiteralUtil::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*LiteralUtil::CreateR0(42.0)); + RoundTripTest(LiteralUtil::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*LiteralUtil::CreateR1({})); + RoundTripTest(LiteralUtil::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); + RoundTripTest(LiteralUtil::CreateR1({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest( - *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest(LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*LiteralUtil::CreateR4({{ + RoundTripTest(LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*LiteralUtil::MakeTuple({})); + RoundTripTest(LiteralUtil::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), - LiteralUtil::CreateR1({2, 3}).get()})); + RoundTripTest(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(1.0), LiteralUtil::CreateR1({2, 3})})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); + RoundTripTest(LiteralUtil::CreateR4FromArray4D(array4d)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/sample_text_test.cc b/tensorflow/compiler/xla/tests/sample_text_test.cc index b4f2b74e3dc9e80f50454b28eb6f2502cef3e681..2b03a0b0b22eb0ae4777417f6640c5f90171d808 100644 --- a/tensorflow/compiler/xla/tests/sample_text_test.cc +++ b/tensorflow/compiler/xla/tests/sample_text_test.cc @@ -19,18 +19,18 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -using tensorflow::gtl::nullopt; +using absl::nullopt; class SampleTextTest : public HloTestBase {}; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index e42c71eb284deb2e50d6ea4b47fa707e4bc14ffc..1dd937a6d0656b53f9e7e0cb25acf80f0c3d59c0 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -30,8 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -46,9 +46,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase { // A template for building and running a binary comparison test. template void TestCompare(NativeT lhs, NativeT rhs, bool expected, - std::function)> - op) { + const std::function)>& op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); @@ -58,9 +57,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase { template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - std::function)> - op) { + const std::function)>& op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); @@ -163,9 +161,9 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { ConvertElementType(a, F32); int64 value = 3LL << 35; - std::unique_ptr a_literal = LiteralUtil::CreateR0(value); + Literal a_literal = LiteralUtil::CreateR0(value); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, static_cast(value), {a_data.get()}); } @@ -227,20 +225,20 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); - std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); - std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); + Literal a_literal = LiteralUtil::CreateR0(2.1f); + Literal b_literal = LiteralUtil::CreateR0(5.5f); + Literal c_literal = LiteralUtil::CreateR0(0.5f); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); std::unique_ptr c_data = - client_->TransferToServer(*c_literal).ConsumeValueOrDie(); + client_->TransferToServer(c_literal).ConsumeValueOrDie(); - XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a"); - XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b"); - XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c"); + XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a"); + XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b"); + XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c"); Mul(Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, @@ -379,9 +377,9 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(div_computation, @@ -390,7 +388,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { .ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend / divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -421,9 +419,9 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(rem_computation, @@ -432,7 +430,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { .ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend % divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -443,8 +441,8 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); Rem(x, ConstantR0(&builder, 80000)); - std::unique_ptr literal = LiteralUtil::CreateR0(87919); - TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); + Literal literal = LiteralUtil::CreateR0(87919); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b21dd56045e1dc11847e213852dea60cd033be7b --- /dev/null +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -0,0 +1,624 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +using absl::nullopt; + +class ScatterTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, Literal* operand, + Literal* scatter_indices, Literal* updates) { + RunTest(hlo_text, {operand, scatter_indices, updates}); + } + + void RunTest(const string& hlo_text, absl::Span args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = + LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Add + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Mul + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { + const string hlo_text = R"( +HloModule TensorFlowScatter_F32 + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = + LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); + Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + Literal updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = u32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); + Literal updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, NegativeIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + Literal updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsUpdateWindow) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd_OobUpdateWindow + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[1,2] parameter(1) + updates = s32[1,2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}}); + Literal updates = LiteralUtil::CreateR3({{{-10, 10}, {-40, 40}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + updates = s32[1,3,2]{2,1,0} parameter(2) + ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={0,1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + Literal operand = LiteralUtil::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = + LiteralUtil::CreateR3({{{10, 20}, {30, 40}, {50, 60}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, ScalarUpdate) { + const char* hlo_text = R"( +HloModule ScalarUpdate + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + updates = s32[] parameter(2) + ROOT scatter = s32[4]{0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = LiteralUtil::CreateR0(25); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, EmptyIndices) { + const string hlo_text = R"( +HloModule EmptyIndices + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[0] parameter(1) + updates = s32[0] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = LiteralUtil::CreateR1({1, 2, 3}); + Literal scatter_indices = LiteralUtil::CreateR1({}); + Literal updates = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index e3d4f98dd7432d1dce7e697586e8b17105dc82e7..f737b5158b3622d677aea5bf64a421a56e2c42dd 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -42,8 +42,8 @@ struct SelectAndScatterTestParam { std::vector operand_shape; std::vector source_shape; Padding padding_type; - tensorflow::gtl::ArraySlice window_dimensions; - tensorflow::gtl::ArraySlice window_strides; + absl::Span window_dimensions; + absl::Span window_strides; }; class SelectAndScatterTest diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index b8ad6668f80a3002eff3cc458997966ee67c8d4b..2cc33ab0963afe8ba2d8e9a6972dcf0622e27c48 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -18,6 +18,11 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -25,16 +30,12 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -using ::tensorflow::str_util::Join; - class SliceTest : public ClientLibraryTestBase {}; TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { @@ -175,8 +176,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { XlaBuilder builder(TestName()); auto original = ConstantR4FromArray4D(&builder, values); Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), - &expected_literal->shape()); + ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001), + &expected_literal.shape()); } struct R1Spec { @@ -193,26 +194,26 @@ class SliceR1Test : public ClientLibraryTestBase, protected: template void Run(const R1Spec& spec) { - // This can't be an std::vector, since you can't grab an ArraySlice of a + // This can't be an std::vector, since you can't grab a Span of a // vector. - tensorflow::gtl::InlinedVector input(spec.input_dim0); + absl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); auto literal = LiteralUtil::CreateR1(input); XlaBuilder builder(TestName()); - auto original = Parameter(&builder, 0, literal->shape(), "p0"); + auto original = Parameter(&builder, 0, literal.shape(), "p0"); Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); // Ditto. - tensorflow::gtl::InlinedVector expected; + absl::InlinedVector expected; for (int i = spec.slice_start; i < spec.slice_limit; i += spec.slice_stride) { expected.push_back(i); } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; @@ -222,9 +223,8 @@ class SliceR1LargeTest : public SliceR1Test {}; string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { const R1Spec& spec = data.param; - return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, - spec.slice_start, spec.slice_limit, - spec.slice_stride); + return absl::StrFormat("%d_%d_%d_%d", spec.input_dim0, spec.slice_start, + spec.slice_limit, spec.slice_stride); } XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } @@ -376,11 +376,11 @@ XLA_TEST_P(SliceR2Test, DoIt) { input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = Parameter(&builder, 0, literal->shape(), "p0"); + auto a = Parameter(&builder, 0, literal.shape(), "p0"); Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR2(&builder, *expected, {arg.get()}); @@ -412,6 +412,7 @@ INSTANTIATE_TEST_CASE_P( R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}}, // R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}}, // R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}}, // + R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}}, // R2Spec{ 511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}}, // R2Spec{ @@ -448,13 +449,11 @@ struct R4Spec { string R4SpecToString(const ::testing::TestParamInfo& data) { const R4Spec& spec = data.param; - return tensorflow::strings::StrCat( // - "input_", Join(spec.input_dims, "x"), // - "__layout_", Join(spec.input_layout, ""), // - "__starts_", Join(spec.slice_starts, "x"), // - "__limits_", Join(spec.slice_limits, "x"), // - "__strides_", Join(spec.slice_strides, "x") // - ); + return absl::StrCat("input_", absl::StrJoin(spec.input_dims, "x"), + "__layout_", absl::StrJoin(spec.input_layout, ""), + "__starts_", absl::StrJoin(spec.slice_starts, "x"), + "__limits_", absl::StrJoin(spec.slice_limits, "x"), + "__strides_", absl::StrJoin(spec.slice_strides, "x")); } class SliceR4Test : public ClientLibraryTestBase, @@ -469,9 +468,9 @@ class SliceR4Test : public ClientLibraryTestBase, XlaBuilder builder(TestName()); auto literal = LiteralUtil::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); - auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); + auto parameter = Parameter(&builder, 0, literal.shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); } diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index be35ec6c6ee4c015755622b2dc9bb92e23af7c85..a9874a918659f1d7403ba0c5cb968e62d7091936 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -20,7 +20,9 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/strings/str_util.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" @@ -44,7 +46,7 @@ ManifestT ReadManifest() { string contents((std::istreambuf_iterator(file_stream)), std::istreambuf_iterator()); - std::vector lines = tensorflow::str_util::Split(contents, '\n'); + std::vector lines = absl::StrSplit(contents, '\n'); for (string& line : lines) { auto comment = line.find("//"); if (comment != string::npos) { @@ -53,8 +55,8 @@ ManifestT ReadManifest() { if (line.empty()) { continue; } - tensorflow::str_util::StripTrailingWhitespace(&line); - std::vector pieces = tensorflow::str_util::Split(line, ' '); + absl::StripTrailingAsciiWhitespace(&line); + std::vector pieces = absl::StrSplit(line, ' '); CHECK_GE(pieces.size(), 1); auto& platforms = manifest[pieces[0]]; for (int64 i = 1; i < pieces.size(); ++i) { @@ -73,8 +75,7 @@ string PrependDisabledIfIndicated(const string& test_case_name, // First try full match: test_case_name.test_name // If that fails, try to find just the test_case_name; this would disable all // tests in the test case. - auto it = manifest.find( - tensorflow::strings::StrCat(test_case_name, ".", test_name)); + auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name)); if (it == manifest.end()) { it = manifest.find(test_case_name); if (it == manifest.end()) { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2647937013222ccfdae98b0c1d141f461020b5c9..5155f0c652c7c6dbba60c421159494fa28072090 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tests/test_utils.h" +#include + +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" namespace xla { @@ -26,89 +29,102 @@ namespace { template void PopulateWithRandomFloatingPointDataImpl(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - // Create uniform numbers between 1 and 1.125 to avoid creating denormal - // numbers. - std::uniform_real_distribution generator(1.0f, 1.125f); - const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice indices) { - // Generate a random uniform number from -0.0625 and 0.0625 and bias it - // with a position dependent number with mean 0.037109375. These number - // should allow for long chains of accumulation without being too close - // to zero or too large to accumulate all numbers accurately. Only do - // this for large literals where the number of elements is much greater - // than 47 otherwise only negative values are produced. - // - // The value is positionally biased using a product of the indices. Add - // one to each index value to avoid collapsing to zero if any of the - // indices are zero. - int64 index_product = 1; - for (int64 i : indices) { - index_product *= (1 + i); - } - const int64 negative_bias = should_index_bias ? 47 : 0; - FloatT index_bias = - static_cast(index_product % 113 - negative_bias) / - static_cast(256.0f); - return static_cast(generator(*engine) - 1.0625f) + index_bias; - })); + if (no_duplicates) { + // Duplicates may be generated if the number of elements in the literal + // exceeds the number of positive values supported by the type. + FloatT next_value = std::numeric_limits::min(); + for (FloatT& value : literal->data()) { + value = next_value; + next_value = + std::nextafter(next_value, std::numeric_limits::max()); + } + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); + } else { + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (FloatT& value : literal->data()) { + value = static_cast(generator(*engine)); + } + } } template void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl(literal, engine); + PopulateWithRandomFloatingPointDataImpl(literal, engine, + no_duplicates); } template <> void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { + // no_duplicates is ignored for half types. Unique values can only be + // generated for arrays with fewer than ~2**16 elements and no_duplicates is + // best-effort anyway. CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl(literal, engine); + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (half& value : literal->data()) { + value = static_cast(generator(*engine)); + } } -// The standard library does not have a case for bfloat16, unsurprisingly, so we -// handle that one specially. template <> void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { + // no_duplicates is ignored for bfloat types. Unique values can only be + // generated for arrays with fewer than ~2**16 elements and no_duplicates is + // best-effort anyway. CHECK(engine != nullptr); - CHECK_EQ(literal->shape().element_type(), BF16); - std::uniform_real_distribution generator(-0.9f, 1.0f); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return static_cast(generator(*engine)); - })); + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (bfloat16& value : literal->data()) { + value = static_cast(generator(*engine)); + } } template -void PopulateWithRandomIntegralData(Literal* literal, - std::minstd_rand0* engine) { +void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - std::uniform_int_distribution generator( - std::numeric_limits::lowest(), std::numeric_limits::max()); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(*engine); - })); + if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) < + std::numeric_limits::max()) { + std::iota(literal->data().begin(), literal->data().end(), 0); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); + } else { + std::uniform_int_distribution generator( + std::numeric_limits::lowest(), std::numeric_limits::max()); + for (IntT& value : literal->data()) { + value = generator(*engine); + } + } } // Similar to MakeFakeLiteral but takes a random number generator engine to -// enable reusing the engine across randomly generated literals. -StatusOr> MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine) { +// enable reusing the engine across randomly generated literals. 'no_duplicates' +// indicates that there should be no duplicate values in each generated +// array. This is uniqueness is best-effort only. Some types (half and bfloat16) +// are not supported and uniqueness cannot be guaranteed if the number of +// elements exceeds the number of different values supported by the type. +StatusOr MakeFakeLiteralInternal(const Shape& shape, + std::minstd_rand0* engine, + bool no_duplicates) { if (ShapeUtil::IsTuple(shape)) { - std::vector> elements; + std::vector elements; for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr element, - MakeFakeLiteralInternal(element_shape, engine)); + TF_ASSIGN_OR_RETURN( + Literal element, + MakeFakeLiteralInternal(element_shape, engine, no_duplicates)); elements.push_back(std::move(element)); } return LiteralUtil::MakeTupleOwned(std::move(elements)); @@ -116,48 +132,52 @@ StatusOr> MakeFakeLiteralInternal( if (engine == nullptr) { return Literal::CreateFromShape(shape); } - auto literal = MakeUnique(shape); + Literal literal(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(&literal, engine, + no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(&literal, engine, + no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(&literal, engine, + no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(&literal, engine, + no_duplicates); break; case S8: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U8: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S16: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U16: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S32: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U32: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S64: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U64: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case PRED: { std::uniform_int_distribution generator(0, 1); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { + TF_CHECK_OK( + literal.Populate([&](absl::Span /*indices*/) { return generator(*engine); })); break; @@ -167,7 +187,7 @@ StatusOr> MakeFakeLiteralInternal( break; default: return Unimplemented("Unsupported type for fake literal generation: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } return std::move(literal); } @@ -176,6 +196,7 @@ enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. ConstantType GetInitValue(const HloComputation& computation) { + // TODO(b/77635120): Add init values, for min, max, and their arg variants. const HloInstruction* const root = computation.root_instruction(); if (computation.num_parameters() != 2 || root->operand_count() != 2 || root->operand(0)->opcode() != HloOpcode::kParameter || @@ -200,24 +221,20 @@ bool NeedsInitValue(const HloUse& use) { const HloInstruction* const instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; - return ( - ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && - op_num == 1) || - (opcode == HloOpcode::kSelectAndScatter && op_num == 2)); + return ((opcode == HloOpcode::kReduceWindow && op_num == 1) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2) || + (opcode == HloOpcode::kReduce && + op_num >= instruction->operand_count() / 2)); } // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr MakeRandomNonwrappingSliceIndex( - const Shape& input_shape, const Shape& slice_shape, - std::minstd_rand0* engine) { - const int64 rank = ShapeUtil::Rank(input_shape); - std::vector start_indices(rank); +Literal MakeRandomIndex(absl::Span index_space, + std::minstd_rand0* engine) { + std::vector start_indices(index_space.size()); if (engine != nullptr) { - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); + for (int i = 0; i < index_space.size(); ++i) { + std::uniform_int_distribution generator(0, index_space[i]); start_indices[i] = generator(*engine); } } @@ -254,6 +271,11 @@ std::vector FindConstrainedUses( auto converted_uses = FindConstrainedUses(dataflow, *instruction); constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), converted_uses.end()); + } else if (opcode == HloOpcode::kSort && + instruction->operand_count() == 2 && op_num == 0) { + // Operand 0 of sort is the array of keys used for key/value + // (two-operand) kSort instructions. + constrained_uses.push_back(instruction); } } } @@ -264,106 +286,141 @@ std::vector FindConstrainedUses( // no constrained uses in the dataflow graph. If such constraints exist, // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). -StatusOr> CreateLiteralForConstrainedUses( - const tensorflow::gtl::ArraySlice constrained_uses, +StatusOr CreateLiteralForConstrainedUses( + const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - HloInstruction* needs_index = nullptr; - HloInstruction* needs_constant = nullptr; + std::vector index_space; + bool no_duplicates = false; + bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr) { - auto needs_index_shape = needs_index->shape(); - auto use_shape = use->shape(); - if (needs_index->opcode() == HloOpcode::kDynamicSlice) { - needs_index_shape = needs_index->operand(0)->shape(); + case HloOpcode::kDynamicUpdateSlice: { + const Shape& indexed_shape = use->operand(0)->shape(); + const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice + ? use->shape() + : use->operand(1)->shape(); + const int64 rank = ShapeUtil::Rank(indexed_shape); + if (!index_space.empty()) { + TF_RET_CHECK(rank == index_space.size()); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = std::min( + index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i)); } - if (use->opcode() == HloOpcode::kDynamicSlice) { - use_shape = use->operand(0)->shape(); - } - if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + } else { + index_space.resize(rank); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); } } - needs_index = use; break; + } case HloOpcode::kReduce: case HloOpcode::kReduceWindow: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->to_apply()); break; case HloOpcode::kSelectAndScatter: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->scatter()); break; + case HloOpcode::kSort: + no_duplicates = true; + break; + default: return Unimplemented( "Constrained operand generation not implemented for %s.", - use->ToString().c_str()); + use->ToString()); } } - if (needs_index != nullptr && needs_constant != nullptr) { - return Unimplemented( - "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " - "constant: %s\n", - needs_index->ToString().c_str(), needs_constant->ToString().c_str()); + int constraint_count = 0; + constraint_count += no_duplicates ? 1 : 0; + constraint_count += !index_space.empty() ? 1 : 0; + constraint_count += needs_constant ? 1 : 0; + if (constraint_count > 1) { + return Unimplemented("Conflicting operand generation constraints."); } - if (needs_index != nullptr) { - return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), - needs_index->shape(), engine); - } else if (needs_constant != nullptr) { + if (!index_space.empty()) { + return MakeRandomIndex(index_space, engine); + } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: - return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::Zero(param.shape().element_type()); case ConstantType::kOne: - return LiteralUtil::One(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::One(param.shape().element_type()); case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. - return MakeFakeLiteralInternal(param.shape(), engine); + return MakeFakeLiteralInternal(param.shape(), engine, + /*no_duplicates=*/false); } } else { - return MakeFakeLiteralInternal(param.shape(), engine); + return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates); } } // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. -StatusOr> MakeConstrainedArgument( - const HloDataflowAnalysis& dataflow, const HloInstruction& param, - std::minstd_rand0* engine) { +StatusOr MakeConstrainedArgument(const HloDataflowAnalysis& dataflow, + const HloInstruction& param, + std::minstd_rand0* engine) { const auto constrained_uses = FindConstrainedUses(dataflow, param); return CreateLiteralForConstrainedUses(constrained_uses, param, engine); } } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random) { - auto engine = pseudo_random ? MakeUnique() : nullptr; - return MakeFakeLiteralInternal(shape, engine.get()); +StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random) { + auto engine = + pseudo_random ? absl::make_unique() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false); } -StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random) { +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random) { + auto engine = + pseudo_random ? absl::make_unique() : nullptr; + return MakeFakeArguments(module, engine.get()); +} + +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - auto engine = pseudo_random ? MakeUnique() : nullptr; - std::vector> arguments(params.size()); + std::vector arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( - *dataflow, *params[i], engine.get())); + arguments[i] = + MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie(); } return std::move(arguments); } -Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) { - return HloVerifier(allow_mixed_precision).Run(module).status(); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision) { + return HloVerifier(/*layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision) + .Run(module) + .status(); } +std::unique_ptr CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + return absl::make_unique( + shape, lhs, rhs, dot_dimension_numbers, precision_config); +} } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index e59f215a9a3ace80d7a23e1bbc40970c7a63ea0d..b3c8a739058475a4e51bae6ad2a98152a6532b47 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -20,14 +20,14 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/stream_executor/platform.h" namespace xla { @@ -57,14 +57,23 @@ class PseudorandomGenerator { // Generates fake data in a literal of the given shape, or returns an error // status if the element type is currently unhandled for fake data // generation. See below for documentation of pseudo_random. -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random = true); +StatusOr MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // -// Will handle special cases such as making sure that indices used for dynamic -// slices are bounded, reduces that call adds use 0 as an init value, etc. +// A best-effort attempt is made to generate the data in a way which produce +// stable computation results across platforms. Specifically: +// +// (1) Init values of reductions should be the identity of the reduction +// computation. +// +// (2) Indices of dynamic slices and update slices should be in bounds. +// +// (3) Keys of key/value sorts should contain no duplicates. +// +// These constraints are best-effort only. // // If pseudo_random is true, the generated numbers will be generated // deterministically in a pseudo random way unless the values are constrated to @@ -75,14 +84,26 @@ StatusOr> MakeFakeLiteral(const Shape& shape, // TODO(b/79942829): Make interesting argument generation fast enough that using // pseudo_random does not save any noticeable amount of time so that the // parameter can be removed. -StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random = true); +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random = true); + +// Overload which accepts a random number generator. This enables generation of +// different random values with sequential calls to MakeFakeArguments by reusing +// the same generator. +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine); // Check that a given module satisfies various constraints before trying to // execute it. -Status VerifyHloModule(HloModule* const module, - bool allow_mixed_precision = false); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision); +// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of +// the LHS with dimension 0 of the RHS with no batch dimensions. +// Both LHS and the RHS must be of rank 2. +std::unique_ptr CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index a2f0338e25977d7c76dbc48b3afc649b77ba4ee2..181e5cbe290b0df0cf605cc4ef4b8a4945b3d367 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -72,5 +73,106 @@ XLA_TEST_F(TestUtilsTest, Token) { TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); } +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + const Literal& index_arg = args[0]; + + EXPECT_EQ(index_arg.Get({0}), 0); + + EXPECT_GE(index_arg.Get({1}), 0); + EXPECT_LE(index_arg.Get({1}), 2); + + EXPECT_GE(index_arg.Get({2}), 0); + EXPECT_LE(index_arg.Get({2}), 3); +} + +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + update_param.1 = f32[1,2,3]{0,1,2} parameter(3) + update_param.2 = f32[3,2,2]{0,1,2} parameter(4) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 5); + const Literal& index_arg = args[0]; + + EXPECT_EQ(index_arg.Get({0}), 0); + + EXPECT_GE(index_arg.Get({1}), 0); + EXPECT_LE(index_arg.Get({1}), 2); + + EXPECT_GE(index_arg.Get({2}), 0); + EXPECT_LE(index_arg.Get({2}), 3); +} + +XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { + // Inputs which are sort keys in key/value sorts should have no duplicates. + auto module = ParseHloString(R"( +HloModule sort.148.1589 + +ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) { + %parameter.0 = f32[1048576]{0} parameter(0) + %parameter.1 = s32[1048576]{0} parameter(1) + ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} +} +)") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + const Literal& key_arg = args[0]; + + tensorflow::gtl::FlatSet key_set; + for (const float& value : key_arg.data()) { + EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); + } +} + +XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) { + // Inputs which are sort keys in key/value sorts should have no duplicates. + auto module = ParseHloString(R"( +HloModule sort.148.1589 + +ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) { + %parameter.0 = s32[1048576]{0} parameter(0) + %parameter.1 = s32[1048576]{0} parameter(1) + ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} +} +)") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + const Literal& key_arg = args[0]; + + tensorflow::gtl::FlatSet key_set; + for (const int32& value : key_arg.data()) { + EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 2bdbd08309a81b201fc224110805549f7fb5bb55..b34fd0f2e873214c509533f29553af914ddc984d 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -15,11 +15,10 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -35,9 +34,8 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { @@ -51,9 +49,8 @@ XLA_TEST_F(TokenHloTest, TokenTree) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -67,7 +64,10 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -84,7 +84,10 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { "param")); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -101,7 +104,10 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT(status.error_message(), ::testing::HasSubstr( @@ -185,9 +191,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(true); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(42, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(42, result.Get({})); } { @@ -196,9 +201,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(false); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(7, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(7, result.Get({})); } } diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 125513ddfd16cb4e742e7d589e22b721307621ee..d6641d257a75945be94d299a1bd4b0366e3759b7 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -69,90 +69,90 @@ class TransferManagerTest : public LocalClientTestBase { }; XLA_TEST_F(TransferManagerTest, TransferR0U32) { - std::unique_ptr literal = LiteralUtil::CreateR0(42); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR0(42); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR0Equal(42, *result); + LiteralTestUtil::ExpectR0Equal(42, result); } XLA_TEST_F(TransferManagerTest, TransferR1F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, - *result); + result); } XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { std::vector test_vector(1024 * 1024); std::iota(test_vector.begin(), test_vector.end(), 0); - std::unique_ptr literal = LiteralUtil::CreateR1(test_vector); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1(test_vector); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR1Equal(test_vector, *result); + LiteralTestUtil::ExpectR1Equal(test_vector, result); } XLA_TEST_F(TransferManagerTest, TransferR1U8) { const char* test_string = "0123456789abcdef"; - std::unique_ptr literal = LiteralUtil::CreateR1U8(test_string); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1U8(test_string); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_EQ(result->GetR1U8AsString(), test_string); + EXPECT_EQ(result.GetR1U8AsString(), test_string); } XLA_TEST_F(TransferManagerTest, TransferR2F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferR2F32AndChangeLayoutTransferringToDevice) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); const Shape ondevice_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -160,101 +160,99 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_FALSE( - LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); + LayoutUtil::Equal(result.shape().layout(), literal.shape().layout())); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple({}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTuple({}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { - std::unique_ptr literal = LiteralUtil::CreateR1( + Literal literal = LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( + Literal literal = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR1( - {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) - .get(), - LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}).get(), - LiteralUtil::CreateR0(complex64(0.3f, -0.4f)).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}), + LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}), + LiteralUtil::CreateR0(complex64(0.3f, -0.4f))}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { @@ -264,54 +262,52 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { // supported. auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result)); } XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { const int64 kIterationCount = 5000; - std::unique_ptr literal1 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - std::unique_ptr literal2 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(456.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-98.0f, 153.0f}).get()}); - - auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); - auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); + Literal literal1 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + Literal literal2 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(456.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}), + LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-98.0f, 153.0f})}); + + auto device_buffer1 = AllocateDeviceBuffer(literal1.shape()); + auto device_buffer2 = AllocateDeviceBuffer(literal2.shape()); auto stream1 = stream_; auto stream2 = stream_->GetOrCreateSubStream(); - std::unique_ptr result1, result2; + Literal result1, result2; // Round trip literals through device in multiple streams asynchronously. for (int i = 0; i < kIterationCount; ++i) { - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1, device_buffer1)); - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2, device_buffer2)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result1, + Literal this_result1, transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result2, + Literal this_result2, transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2)); result1 = std::move(this_result1); result2 = std::move(this_result2); } - EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2)); } class TransferDeviceToHostBenchmark : public TransferManagerTest { @@ -323,20 +319,19 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); } tensorflow::testing::StopTiming(); @@ -355,17 +350,16 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); } tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 2fd70b72b52f360fc74a73cd13d401b7dac6e708..619d2a388b5646c31f0a61f709a2ab3067e39c03 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -50,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests a tuple made of scalar constants. @@ -65,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar1).get(), - LiteralUtil::CreateR0(constant_scalar2).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar1), + LiteralUtil::CreateR0(constant_scalar2)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests the creation of tuple data. @@ -87,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) { ConstantR1(&builder, constant_vector), ConstantR2(&builder, constant_matrix)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of tuple data. @@ -101,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { Tuple(&builder, {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), - LiteralUtil::CreateR1({}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(7.0), LiteralUtil::CreateR1({})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of an empty tuple. @@ -112,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); Tuple(&builder, {}); auto expected = LiteralUtil::MakeTuple({}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. @@ -195,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ConstantR2(&builder, constant_matrix)}); Tuple(&builder, {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2(constant_matrix).get(), - LiteralUtil::CreateR1(constant_vector).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2(constant_matrix), + LiteralUtil::CreateR1(constant_vector)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { @@ -217,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true} auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(direction).get(), - LiteralUtil::CreateR0(!direction).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(direction), + LiteralUtil::CreateR0(!direction)}); - ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()}, error_spec_); } } @@ -286,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, TuplesInAMap) { @@ -331,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, true), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), - LiteralUtil::CreateR1(vec2).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec1), LiteralUtil::CreateR1(vec2)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { @@ -407,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { @@ -422,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); auto expected_s = LiteralUtil::CreateR0(42.0); auto expected_inner_tuple = - LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + LiteralUtil::MakeTuple({&expected_v1, &expected_s}); auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); - auto expected = - LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { @@ -445,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::MakeTuple( - { - LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), - LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), - }) - .get(), - LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1.0, 2.0, 3.0}), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}), + }), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}), })) .ConsumeValueOrDie(); @@ -483,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) { std::unique_ptr arg0 = client_ - ->TransferToServer(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0({1, 2}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({{10, 20}, {30, 40}}) - .get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0({1, 2}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({{10, 20}, {30, 40}}), LiteralUtil::CreateR2( {{{100, 200}, {300, 400}}, {{1000, 2000}, {3000, 4000}}, - {{10000, 20000}, {30000, 40000}}}) - .get()}) - .get()})) + {{10000, 20000}, {30000, 40000}}})})})) .ConsumeValueOrDie(); std::unique_ptr arg1 = client_ ->TransferToServer( - *LiteralUtil::CreateR1({{1, 2}, {1, -2}})) + LiteralUtil::CreateR1({{1, 2}, {1, -2}})) .ConsumeValueOrDie(); auto sum = LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = MakeUnique(sum->shape()); - ASSERT_TRUE(prod->Populate( - [&sum](tensorflow::gtl::ArraySlice indexes) { - return sum->Get(indexes) * - (indexes[indexes.size() - 1] == 0 - ? complex64(1, 2) - : complex64(1, -2)); - }) + Literal prod(sum.shape()); + ASSERT_TRUE(prod.Populate([&sum](absl::Span indexes) { + return sum.Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? complex64(1, 2) + : complex64(1, -2)); + }) .ok()); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(), - LiteralUtil::CreateR0({123, 456}).get()}); - ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices({prod, sum}), + LiteralUtil::CreateR0({123, 456})}); + ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()}, error_spec_); } @@ -540,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { .ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); - auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); + auto result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), + result)); } // Disabled on interpreter due to lack of outfeed. @@ -580,16 +570,15 @@ XLA_TEST_F(TupleHloTest, tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { TF_EXPECT_OK(Execute(std::move(module), - {param0.get(), param1.get(), param1.get(), - param0.get(), param4.get()}) + {¶m0, ¶m1, ¶m1, ¶m0, ¶m4}) .status()); })); auto expected = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3})); - auto literal = MakeUnique(); + auto literal = Literal::CreateFromShape(expected.shape()); TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - backend().default_stream_executor(), expected->shape(), literal.get())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); + backend().default_stream_executor(), expected.shape(), literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 20ae68ab74026936c43e5f525eb796eb402a19cb..4fbd7f2fb174ac899c1e3b23801986cb52db96a2 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -100,9 +100,9 @@ void UnaryOpTest::AbsTestHelper() { {-inf(), 0}}); Abs(arg); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR1({2, 25, 0, 0.5, inf(), inf()}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -113,9 +113,9 @@ void UnaryOpTest::SignTestHelper() { {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); Sign(arg); - std::unique_ptr expected = LiteralUtil::CreateR1( + Literal expected = LiteralUtil::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -127,9 +127,8 @@ void UnaryOpTest::SignAbsTestHelper() { auto abs = Abs(arg); Sub(Mul(sign, ConvertElementType(abs, C64)), arg); - std::unique_ptr expected = - LiteralUtil::CreateR1({0, 0, 0, 0}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR1({0, 0, 0, 0}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { @@ -172,9 +171,8 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { Add(sgnc, ConvertElementType( Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); - std::unique_ptr expected = - LiteralUtil::CreateR0({-2.6f, 0.8f}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR0({-2.6f, 0.8f}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, SignTestR1) { @@ -190,25 +188,6 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { SignAbsTestHelper(); } -XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { - XlaBuilder builder(TestName()); - auto arg = ConstantR1( - &builder, {2, 25, 0, 123, std::numeric_limits::max()}); - Abs(arg); - - ComputeAndCompareR1( - &builder, {2, 25, 0, 123, std::numeric_limits::max()}, {}); -} - -XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { - XlaBuilder builder(TestName()); - auto arg = ConstantR1( - &builder, {2, 25, 0, 123, std::numeric_limits::max()}); - Sign(arg); - - ComputeAndCompareR1(&builder, {1, 1, 0, 1, 1}, {}); -} - XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { XlaBuilder builder(TestName()); auto arg = ConstantR2(&builder, {{1.0, -2.0}, {-3.0, 4.0}}); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 1bdf1867b9330b715b0ba4aca71d56307883c775..8b1b9e151992296b9d022ae1d9d974eadd2074a8 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -348,9 +348,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // have all reached 2.0. auto expected_data = LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = LiteralUtil::MakeTuple({expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { @@ -401,11 +401,10 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { auto expected_w1 = LiteralUtil::CreateR1({1.0f, 1.0f, 1.0f}); auto expected_w2 = LiteralUtil::CreateR1({2.0f, 2.0f, 2.0f}); auto expected_w3 = LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple( + {&expected_counter, &expected_w2, &expected_w3, &expected_w1}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { @@ -510,10 +509,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPredicateTupleResult) { @@ -557,9 +555,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_predicate = LiteralUtil::CreateR0(true); - auto expected = LiteralUtil::MakeTuple( - {expected_counter.get(), expected_predicate.get()}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); + auto expected = + LiteralUtil::MakeTuple({&expected_counter, &expected_predicate}); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0)); } TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { @@ -602,10 +600,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR0(7); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests two while nodes when the result type T is a Tuple and the second @@ -766,9 +763,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } -// Test while nodes that share the while body computation. -// TODO(b/37245345): Fails on GPU backend. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { +TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -886,10 +881,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests a while node when the result type T is a vector of S32. @@ -977,11 +971,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto expected_element = LiteralUtil::CreateR1({1, 1}); auto expected = - LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({&expected_element, &expected_element}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1005,7 +999,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareR1(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1031,7 +1025,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(42))); + client_->TransferToServer(LiteralUtil::CreateR0(42))); ComputeAndCompareR0(&outer, 43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1070,12 +1064,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(1))); + client_->TransferToServer(LiteralUtil::CreateR0(1))); auto add1 = LiteralUtil::CreateR0(15); auto add2 = LiteralUtil::CreateR0(16); - auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + auto expected = LiteralUtil::MakeTuple({&add1, &add2}); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1228,7 +1222,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { GetTupleElement(while_instruction, 3); TF_ASSERT_OK_AND_ASSIGN( - auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2( + auto param_value, client_->TransferToServer(LiteralUtil::CreateR2( {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2( @@ -1258,9 +1252,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { XlaBuilder builder(TestName()); While(condition, body, ConstantR0(&builder, 0)); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(false))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(false))); ComputeAndCompareR0(&builder, 2, {}); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 11f3efb1f34ad23ebdcbb65c90aa5fb7a6adeae5..db5a824de08edeb81b5deb047507dc6158833008 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -16,6 +16,10 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -29,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -81,8 +84,7 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, gtl::FlatMap* parsed_results, - tensorflow::gtl::ArraySlice opcodes_to_ignore = - {}) { + absl::Span opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; string match_percentage = R"(\d+\.\d*% +\d+Σ)"; string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; @@ -99,7 +101,7 @@ Status ParseOneProfileOutputLine( string match_opcode = expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])"; - string regexp_pattern = tensorflow::strings::StrCat( + string regexp_pattern = absl::StrCat( " +", match_cycles, separator, match_usecs, separator, match_flops, separator, match_trops, separator, match_bytes_per_sec, separator, match_bytes_per_cycle, separator, match_opcode); @@ -116,7 +118,7 @@ Status ParseOneProfileOutputLine( ", Regexp: ", regexp_pattern); } - if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { + if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); } @@ -142,14 +144,14 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, @@ -169,10 +171,10 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, ServiceExecutableRunOptions run_options( exec_run_options, /*borrow_stream=*/nullptr, backend->eigen_intra_op_thread_pool()); + std::vector args = {&lhs_arg, &rhs_arg}; TF_ASSERT_OK_AND_ASSIGN( auto execution_result, - executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg}, - &hlo_execution_profile)); + executable->ExecuteOnStream(&run_options, args, &hlo_execution_profile)); TF_ASSERT_OK(stream_ptr->BlockHostUntilDone()); (void)execution_result; @@ -204,7 +206,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { rhs_shape); std::vector profile_output_lines = - tensorflow::str_util::Split(profile_output, '\n'); + absl::StrSplit(profile_output, '\n'); gtl::FlatMap parsed_profile_lines; @@ -291,22 +293,20 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { matrix_shape); std::vector profile_output_lines = - tensorflow::str_util::Split(profile_output, '\n'); + absl::StrSplit(profile_output, '\n'); auto while_body_profile_start = - c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { - return tensorflow::str_util::StartsWith(s, - "Execution profile for body"); + absl::c_find_if(profile_output_lines, [](absl::string_view s) { + return absl::StartsWith(s, "Execution profile for body"); }); ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); - auto while_body_profile_end = - std::find_if(while_body_profile_start, profile_output_lines.end(), - [](tensorflow::StringPiece s) { - return tensorflow::str_util::StartsWith( - s, "********** microseconds report **********"); - }); + auto while_body_profile_end = std::find_if( + while_body_profile_start, profile_output_lines.end(), + [](absl::string_view s) { + return absl::StartsWith(s, "********** microseconds report **********"); + }); // We emit a blank line before the "********** microseconds report **********" // line. diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index a075195618c42aaa11f7b1c17730e67889a2c308..15603619b62d8f45cdce97ac7d83924a78f88cf3 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -32,16 +32,14 @@ GTEST_API_ int main(int argc, char** argv) { // If the --benchmarks flag is passed in then only run the benchmarks, not the // tests. for (int i = 1; i < argc; i++) { - tensorflow::StringPiece arg(argv[i]); - if (arg == "--benchmarks" || - tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + absl::string_view arg(argv[i]); + if (arg == "--benchmarks" || absl::StartsWith(arg, "--benchmarks=")) { const char* pattern = nullptr; - if (tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + if (absl::StartsWith(arg, "--benchmarks=")) { pattern = argv[i] + strlen("--benchmarks="); } else { // Handle flag of the form '--benchmarks foo' (no '='). - if (i + 1 >= argc || - tensorflow::str_util::StartsWith(argv[i + 1], "--")) { + if (i + 1 >= argc || absl::StartsWith(argv[i + 1], "--")) { LOG(ERROR) << "--benchmarks flag requires an argument."; return 2; } diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 897123d7606db60abc1105b03beb3f23ab249579..cdde88c1359416d423685f330e9cbdf77948034f 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -20,25 +20,27 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { -StatusOr> TextLiteralReader::ReadPath( - tensorflow::StringPiece path) { - CHECK(!tensorflow::str_util::EndsWith(path, ".gz")) +StatusOr TextLiteralReader::ReadPath(absl::string_view path) { + CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; Status s = @@ -54,34 +56,7 @@ StatusOr> TextLiteralReader::ReadPath( TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) : file_(file) {} -namespace { -// This is an optimized version of tensorflow::str_util::Split which uses -// StringPiece for the delimited strings and uses an out parameter for the -// result to avoid vector creation/destruction. -void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim, - std::vector* result) { - result->clear(); - - if (text.empty()) { - return; - } - - // The following loop is a little strange: its bound is text.size() + 1 - // instead of the more typical text.size(). - // The final iteration of the loop (when i is equal to text.size()) handles - // the trailing token. - size_t token_start = 0; - for (size_t i = 0; i < text.size() + 1; i++) { - if (i == text.size() || text[i] == delim) { - tensorflow::StringPiece token(text.data() + token_start, i - token_start); - result->push_back(token); - token_start = i + 1; - } - } -} -} // namespace - -StatusOr> TextLiteralReader::ReadAllLines() { +StatusOr TextLiteralReader::ReadAllLines() { tensorflow::io::RandomAccessInputStream stream(file_.get()); tensorflow::io::BufferedInputStream buf(&stream, 65536); string shape_string; @@ -90,63 +65,57 @@ StatusOr> TextLiteralReader::ReadAllLines() { return s; } - tensorflow::StringPiece sp(shape_string); - if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { - string tmp = std::string(sp); - shape_string = tmp; - } + absl::StripAsciiWhitespace(&shape_string); TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); if (shape.element_type() != F32) { return Unimplemented( "unsupported element type for text literal reading: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } - auto result = MakeUnique(shape); + Literal result(shape); const float fill = std::numeric_limits::quiet_NaN(); - result->PopulateWithValue(fill); - std::vector pieces; - std::vector coordinates; + result.PopulateWithValue(fill); + std::vector pieces; + std::vector coordinates; std::vector coordinate_values; string line; while (buf.ReadLine(&line).ok()) { - SplitByDelimToStringPieces(line, ':', &pieces); - tensorflow::StringPiece coordinates_string = pieces[0]; - tensorflow::StringPiece value_string = pieces[1]; - tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string); - tensorflow::str_util::RemoveWhitespaceContext(&value_string); - if (!tensorflow::str_util::ConsumePrefix(&coordinates_string, "(")) { + pieces = absl::StrSplit(line, ':'); + absl::string_view coordinates_string = + absl::StripAsciiWhitespace(pieces[0]); + absl::string_view value_string = absl::StripAsciiWhitespace(pieces[1]); + if (!absl::ConsumePrefix(&coordinates_string, "(")) { return InvalidArgument( - "expected '(' at the beginning of coordinates: \"%s\"", line.c_str()); + "expected '(' at the beginning of coordinates: \"%s\"", line); } - if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) { + if (!absl::ConsumeSuffix(&coordinates_string, ")")) { return InvalidArgument("expected ')' at the end of coordinates: \"%s\"", - line.c_str()); + line); } float value; - if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(), - &value)) { + if (!absl::SimpleAtof(value_string, &value)) { return InvalidArgument("could not parse value as float: \"%s\"", - std::string(value_string).c_str()); + value_string); } - SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); + coordinates = absl::StrSplit(coordinates_string, ','); coordinate_values.clear(); - for (tensorflow::StringPiece piece : coordinates) { + for (absl::string_view piece : coordinates) { int64 coordinate_value; - if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { + if (!absl::SimpleAtoi(piece, &coordinate_value)) { return InvalidArgument( "could not parse coordinate member as int64: \"%s\"", - std::string(piece).c_str()); + std::string(piece)); } coordinate_values.push_back(coordinate_value); } if (coordinate_values.size() != shape.dimensions_size()) { return InvalidArgument( - "line did not have expected number of coordinates; want %d got %zu: " + "line did not have expected number of coordinates; want %d got %u: " "\"%s\"", - shape.dimensions_size(), coordinate_values.size(), line.c_str()); + shape.dimensions_size(), coordinate_values.size(), line); } - result->Set(coordinate_values, value); + result.Set(coordinate_values, value); } return std::move(result); } diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index 708e8c80d8b5c09454eb64d4e12df51a5b7ea628..c40b43279f56fbd6e8ec91cc45c1f8e4cac8b5ef 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,11 +18,11 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -41,8 +41,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr> ReadPath( - tensorflow::StringPiece path); + static StatusOr ReadPath(absl::string_view path); private: // Ownership of file is transferred. @@ -50,7 +49,7 @@ class TextLiteralReader { // Parses a shape string on the first line, followed by lines of values to the // end of the file. - StatusOr> ReadAllLines(); + StatusOr ReadAllLines(); // Owns the file being read std::unique_ptr file_; diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 92f9b4f9f0efa2dc08287bdcbefc88f879164308..1fab4e3a08dd3d76a6efeaabe7bf8ab96892e638 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -42,16 +42,15 @@ TEST(TextLiteralReaderTest, ReadsR3File) { tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents) .ok()); - std::unique_ptr literal = - TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); + Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); EXPECT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape())); - EXPECT_EQ(42.5, literal->Get({0, 0, 0})); - EXPECT_EQ(43.5, literal->Get({0, 0, 1})); - EXPECT_EQ(44.5, literal->Get({0, 0, 2})); - EXPECT_EQ(45.5, literal->Get({0, 1, 0})); - EXPECT_EQ(46.5, literal->Get({0, 1, 1})); - EXPECT_EQ(47.5, literal->Get({0, 1, 2})); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape())); + EXPECT_EQ(42.5, literal.Get({0, 0, 0})); + EXPECT_EQ(43.5, literal.Get({0, 0, 1})); + EXPECT_EQ(44.5, literal.Get({0, 0, 2})); + EXPECT_EQ(45.5, literal.Get({0, 1, 0})); + EXPECT_EQ(46.5, literal.Get({0, 1, 1})); + EXPECT_EQ(47.5, literal.Get({0, 1, 2})); } } // namespace diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 24e0784741a4c9779b0adb7a7740c3d6e2fb033a..7289ae7df65e56652eeeb67e536e4c721d97d999 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -17,23 +17,23 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" namespace xla { -/* static */ Status TextLiteralWriter::WriteToPath( - const Literal& literal, tensorflow::StringPiece path) { +/* static */ Status TextLiteralWriter::WriteToPath(const Literal& literal, + absl::string_view path) { std::unique_ptr f; - auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); + auto s = tensorflow::Env::Default()->NewWritableFile(string(path), &f); if (!s.ok()) { return s; } @@ -46,16 +46,14 @@ namespace xla { Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( - [f_ptr, &status](tensorflow::gtl::ArraySlice indices, - const string& value) { + [f_ptr, &status](absl::Span indices, const string& value) { if (!status.ok()) { return; } - string coordinates = tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(indices, ", "), ")"); + string coordinates = + absl::StrCat("(", absl::StrJoin(indices, ", "), ")"); - status = f_ptr->Append( - tensorflow::strings::StrCat(coordinates, ": ", value, "\n")); + status = f_ptr->Append(absl::StrCat(coordinates, ": ", value, "\n")); }); auto ignored = f->Close(); return status; diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 159ac1b7e1b6f9c07dac795fb640cd0b2d284bcb..34de8572d638067b327711017ee173b16c8da21e 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -37,8 +37,7 @@ namespace xla { // This should be readable by xla::TextLiteralReader. class TextLiteralWriter { public: - static Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, absl::string_view path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 4ea02faffcd52065b05c0444202bd1a3d9d87ee6..5cbaf2fcc192c48092272094710ccaf5c9cf9616 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -37,7 +37,7 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) { }); string path = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); - ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path)); + ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path)); string contents; TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path, &contents)); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 40d28a57bfddd3403cad8252df985b746362631f..3a086c66bbb37965b1ad7c83a93f0054ae723e87 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -24,6 +24,8 @@ tf_cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/base", + "@com_google_absl//absl/strings", ], ) @@ -42,6 +44,7 @@ cc_library( "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -67,6 +70,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -94,6 +98,7 @@ cc_library( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], alwayslink = True, ) @@ -172,6 +177,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -191,6 +197,9 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -210,6 +219,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index f20dcef382b86d27d7c176ae7e4132ad1db7b901..c866a13de7543fc948311f94708bc6b904717b62 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -28,6 +28,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -46,7 +46,7 @@ limitations under the License. namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -77,8 +77,8 @@ int main(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index f0af0580c1fbca455c6ed5f87f82971faee50a06..4375e7c138c9e8d193feaa7a39d63946c4ea3086 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -19,6 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -29,9 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -44,16 +44,14 @@ class OperationDumper : public DfsHloVisitorWithDefault { explicit OperationDumper(const string& path) : path_(path) {} Status DefaultAction(HloInstruction* hlo) override { - string params = tensorflow::str_util::Join( + string params = absl::StrJoin( hlo->operands(), ", ", [](string* out, const HloInstruction* operand) { - tensorflow::strings::StrAppend( - out, ShapeUtil::HumanString(operand->shape())); + absl::StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); // Spit `op_name(params...) -> result_type :: path` to stdout. - std::cout << tensorflow::strings::Printf( - "%s :: (%s) -> %s :: %s\n", HloOpcodeString(hlo->opcode()).c_str(), - params.c_str(), ShapeUtil::HumanString(hlo->shape()).c_str(), - path_.c_str()); + std::cout << absl::StrFormat("%s :: (%s) -> %s :: %s\n", + HloOpcodeString(hlo->opcode()), params, + ShapeUtil::HumanString(hlo->shape()), path_); return Status::OK(); } @@ -61,7 +59,7 @@ class OperationDumper : public DfsHloVisitorWithDefault { string path_; }; -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -106,8 +104,8 @@ void RealMain(tensorflow::gtl::ArraySlice args) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index f03e1b1f965af761c101555fd0275bc0425b9cf0..723569862c7550387e95003e3a673743464b67b8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -34,7 +34,7 @@ limitations under the License. namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { +void RealMain(absl::Span args, bool compile) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -102,8 +102,8 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage; - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args, compile); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index dc5c106d02cb679f3e6f5b2bea40bbb42f8bd1cc..07ef5ff656bb48519a700a1d7d6c60b655a40ed6 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -45,7 +45,7 @@ using tensorflow::Env; namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -78,8 +78,8 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc index eb7bff053b1fc028fdb6930dbc496c3b6d9fae47..0c3ec5934e546f551089f715dbbe6f4479e56c3c 100644 --- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc +++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "absl/base/casts.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/platform/env.h" @@ -67,9 +67,8 @@ int main(int argc, char** argv) { floats.push_back(value); } - tensorflow::StringPiece content( - tensorflow::bit_cast(floats.data()), - floats.size() * sizeof(float)); + tensorflow::StringPiece content(absl::bit_cast(floats.data()), + floats.size() * sizeof(float)); TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output_file, content)); return 0; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index be4cf4318b33f41fc611ea90a1a02198e23b84e4..0c41f227b31ebe1f01073785ea2a666093aefdb3 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -40,6 +40,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -59,7 +60,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -121,11 +121,10 @@ StatusOr ReplayComputation(const HloSnapshot& module, } } else { // use recorded data if available for (const auto& proto : module.arguments()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - Literal::CreateFromProto(proto)); + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto)); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer data, - client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); scoped_shaped_buffer_arguments.push_back(std::move(data)); } for (const auto& argument : scoped_shaped_buffer_arguments) { @@ -160,13 +159,13 @@ StatusOr ReplayComputation(const HloSnapshot& module, // concurrent infeed occur via the fake_infeed_shape, or when // --generate_fake_infeed is passed and there exists an infeed operation in // the HloSnapshot. - tensorflow::gtl::optional pool; - std::unique_ptr data; + absl::optional pool; + Literal data; if (provide_infeed) { data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); } auto transfer_infeed = [&data, client]() { - TF_CHECK_OK(client->TransferToInfeed(*data)); + TF_CHECK_OK(client->TransferToInfeed(data)); }; if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", @@ -196,7 +195,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); - tensorflow::gtl::optional result; + absl::optional result; for (int i = 0; i < opts.num_runs; ++i) { // If xla_hlo_profile is enabled, print a noisy message before the last run, // making it easier to separate this profile from the others in the logspam. @@ -214,18 +213,22 @@ StatusOr ReplayComputation(const HloSnapshot& module, << "s: " << module.hlo().hlo_module().name(); } - TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + TF_ASSIGN_OR_RETURN(Literal result_literal, client->ShapedBufferToLiteral(*result)); - return std::move(*result_literal); + return result_literal; } StatusOr ParseInputFile(const string& filename, const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); HloSnapshot snapshot; - if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + auto s = tensorflow::ReadBinaryProto(env, filename, &snapshot); + if (s.ok()) { return snapshot; } + if (s.code() == tensorflow::error::NOT_FOUND) { + return s; + } CHECK(opts.use_fake_data) << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " "and textual HLO don't carry real data."; @@ -246,10 +249,10 @@ StatusOr ParseInputFile(const string& filename, } fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", filename.c_str()); - return InvalidArgument("Could not parse %s.", filename.c_str()); + return InvalidArgument("Could not parse %s.", filename); } -int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { +int RealMain(absl::Span args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; @@ -258,6 +261,9 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { StatusOr maybe_snapshot = ParseInputFile(arg, opts); if (maybe_snapshot.ok()) { snapshots.push_back(std::move(maybe_snapshot).ValueOrDie()); + } else { + LOG(ERROR) << "Can't handle file " << arg << ": " + << maybe_snapshot.status(); } } @@ -298,11 +304,11 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { result.ToString().c_str()); auto& snapshot = snapshots[i]; if (snapshot.has_result()) { - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal->ToString().c_str()); + literal.ToString().c_str()); } } } @@ -337,7 +343,7 @@ int main(int argc, char** argv) { LOG(QFATAL) << usage; } - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] return xla::tools::RealMain(args, opts); } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index 51909190a3ef20c3df78d08796e88bdbb650609d..4f8852f8c11fb749ef851bc4faf176fcc5cb3524 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -40,8 +40,8 @@ int main(int argc, char **argv) { xla::LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], &literal_proto)); - std::unique_ptr literal = + xla::Literal literal = xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", literal->ToString().c_str()); + fprintf(stderr, "%s\n", literal.ToString().c_str()); } diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 4e53fafcc97ff53afc5713e7ed8ee5222fac316b..cdf306dfd1027cf6022c5d8ae844b4308f580e8d 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -45,7 +45,7 @@ limitations under the License. namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -66,8 +66,8 @@ void RealMain(tensorflow::gtl::ArraySlice args) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 48c837481181f6ad8f864569fd62e0e23fa02ecd..4b5c276bdf66f3dc5364aae4654b13a625b0a4f7 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -36,16 +36,16 @@ int main(int argc, char **argv) { LOG(QFATAL) << "Usage: " << argv[0] << " "; } - std::unique_ptr literal = + xla::Literal literal = xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); - LOG(INFO) << "literal: " << *literal; - fprintf(stderr, "%s\n", literal->ToString().c_str()); - if (literal->shape().element_type() == xla::F32) { - float min = *std::min_element(literal->data().begin(), - literal->data().end()); - float max = *std::max_element(literal->data().begin(), - literal->data().end()); + LOG(INFO) << "literal: " << literal; + fprintf(stderr, "%s\n", literal.ToString().c_str()); + if (literal.shape().element_type() == xla::F32) { + float min = *std::min_element(literal.data().begin(), + literal.data().end()); + float max = *std::max_element(literal.data().begin(), + literal.data().end()); fprintf(stderr, "min: %a=%f\n", min, min); fprintf(stderr, "max: %a=%f\n", max, max); } diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index e43498e381b8e63543e2ddda08ca7c0df91817e4..68cab7387cf1576072f96878b50f07def6862d8b 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -18,12 +18,13 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stacktrace.h" @@ -54,111 +55,28 @@ ScopedLoggingTimer::~ScopedLoggingTimer() { } } -Status AddStatus(Status prior, tensorflow::StringPiece context) { +Status AddStatus(Status prior, absl::string_view context) { CHECK(!prior.ok()); - return Status{prior.code(), tensorflow::strings::StrCat( - context, ": ", prior.error_message())}; + return Status{prior.code(), + absl::StrCat(context, ": ", prior.error_message())}; } -Status AppendStatus(Status prior, tensorflow::StringPiece context) { +Status AppendStatus(Status prior, absl::string_view context) { CHECK(!prior.ok()); - return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(), - ": ", context)}; + return Status{prior.code(), + absl::StrCat(prior.error_message(), ": ", context)}; } -// Implementation note: we can't common these out (without using macros) because -// they all need to va_start/va_end their varargs in their frame. - -Status InvalidArgumentV(const char* format, va_list args) { - string message; - tensorflow::strings::Appendv(&message, format, args); - return WithLogBacktrace(tensorflow::errors::InvalidArgument(message)); -} - -Status InvalidArgument(const char* format, ...) { - va_list args; - va_start(args, format); - Status result = InvalidArgumentV(format, args); - va_end(args); - return result; -} - -Status Unimplemented(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Unimplemented(message)); -} - -Status InternalError(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Internal(message)); -} - -Status FailedPrecondition(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::FailedPrecondition(message)); -} - -Status Cancelled(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Cancelled(message)); -} - -Status ResourceExhausted(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::ResourceExhausted(message)); -} - -Status NotFound(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::NotFound(message)); -} - -Status Unavailable(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Unavailable(message)); -} - -string Reindent(tensorflow::StringPiece original, - const tensorflow::StringPiece indentation) { - std::vector pieces = tensorflow::str_util::Split( - tensorflow::StringPiece(original.data(), original.size()), '\n'); - return tensorflow::str_util::Join( - pieces, "\n", [indentation](string* out, string s) { - tensorflow::StringPiece piece(s); - tensorflow::str_util::RemoveWhitespaceContext(&piece); - tensorflow::strings::StrAppend(out, indentation, piece); - }); +string Reindent(absl::string_view original, + const absl::string_view indentation) { + std::vector pieces = + absl::StrSplit(absl::string_view(original.data(), original.size()), '\n'); + return absl::StrJoin(pieces, "\n", [indentation](string* out, string s) { + absl::StrAppend(out, indentation, absl::StripAsciiWhitespace(s)); + }); } -bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { +bool IsPermutation(absl::Span permutation, int64 rank) { if (rank != permutation.size()) { return false; } @@ -172,7 +90,7 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { } std::vector InversePermutation( - tensorflow::gtl::ArraySlice input_permutation) { + absl::Span input_permutation) { DCHECK(IsPermutation(input_permutation, input_permutation.size())); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { @@ -181,8 +99,8 @@ std::vector InversePermutation( return output_permutation; } -std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, - tensorflow::gtl::ArraySlice p2) { +std::vector ComposePermutations(absl::Span p1, + absl::Span p2) { CHECK_EQ(p1.size(), p2.size()); std::vector output; for (size_t i = 0; i < p1.size(); ++i) { @@ -191,7 +109,7 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation) { +bool IsIdentityPermutation(absl::Span permutation) { for (int64 i = 0; i < permutation.size(); ++i) { if (permutation[i] != i) { return false; @@ -212,7 +130,7 @@ PaddingConfig MakeNoPaddingConfig(int64 rank) { } PaddingConfig MakeEdgePaddingConfig( - tensorflow::gtl::ArraySlice> padding) { + absl::Span> padding) { PaddingConfig padding_config; for (const std::pair& dim : padding) { auto dimension = padding_config.add_dimensions(); @@ -234,20 +152,20 @@ bool HasInteriorPadding(const PaddingConfig& config) { namespace { string HumanReadableNumOps(double flops, double nanoseconds, - tensorflow::StringPiece op_prefix) { + absl::string_view op_prefix) { if (nanoseconds == 0) { - return tensorflow::strings::StrCat("NaN ", op_prefix, "OP/s"); + return absl::StrCat("NaN ", op_prefix, "OP/s"); } double nano_flops = flops / nanoseconds; string throughput = tensorflow::strings::HumanReadableNum( static_cast(nano_flops * 1e9)); - tensorflow::StringPiece sp(throughput); + absl::string_view sp(throughput); // Use the more common "G(FLOPS)", rather than "B(FLOPS)" - if (tensorflow::str_util::EndsWith(sp, "B") || // Ends in 'B', ignoring case - tensorflow::str_util::EndsWith(sp, "b")) { + if (absl::EndsWith(sp, "B") || // Ends in 'B', ignoring case + absl::EndsWith(sp, "b")) { *throughput.rbegin() = 'G'; } - throughput += tensorflow::strings::StrCat(op_prefix, "OP/s"); + throughput += absl::StrCat(op_prefix, "OP/s"); return throughput; } } // namespace @@ -260,8 +178,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) { return HumanReadableNumOps(trops, nanoseconds, "TR"); } -void LogLines(int sev, tensorflow::StringPiece text, const char* fname, - int lineno) { +void LogLines(int sev, absl::string_view text, const char* fname, int lineno) { const int orig_sev = sev; if (sev == tensorflow::FATAL) { sev = tensorflow::ERROR; @@ -275,7 +192,7 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname, size_t cur = 0; while (cur < text.size()) { size_t eol = text.find('\n', cur); - if (eol == tensorflow::StringPiece::npos) { + if (eol == absl::string_view::npos) { eol = text.size(); } auto msg = text.substr(cur, eol - cur); @@ -290,14 +207,13 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname, } } -int64 Product(tensorflow::gtl::ArraySlice xs) { +int64 Product(absl::Span xs) { return std::accumulate(xs.begin(), xs.end(), static_cast(1), std::multiplies()); } -std::vector> CommonFactors( - tensorflow::gtl::ArraySlice a, - tensorflow::gtl::ArraySlice b) { +std::vector> CommonFactors(absl::Span a, + absl::Span b) { CHECK_EQ(Product(a), Product(b)); if (0 == Product(a)) { return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())}; diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 5ae099a4622bb7116c7a17f93060b699ead6e3a6..8ce741647414a1fa75e6d706ec1e719ace7b7cc8 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -24,17 +24,20 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -54,7 +57,7 @@ Status WithLogBacktrace(const Status& status); // the InlinedVector will just behave like an std::vector<> and allocate the // memory to store its values. static constexpr int kInlineRank = 8; -using DimensionVector = tensorflow::gtl::InlinedVector; +using DimensionVector = absl::InlinedVector; // RAII timer that logs with a given label the wall clock time duration in human // readable form. This differs from base's ElapsedTimer primarily in that it @@ -98,65 +101,63 @@ struct ScopedLoggingTimer { uint64 start_micros; }; -// Given a vector, returns a MutableArraySlice that points at its +// Given a vector, returns a Span that points at its // internals. // // Warning: if the vector is updated its storage pointer may change, so use this // with caution (ideally in limited scopes with temporary lifetimes). template -tensorflow::gtl::MutableArraySlice MutableByteSlice(std::vector* v) { - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(v->data()), v->size() * sizeof(T)); +absl::Span MutableByteSlice(std::vector* v) { + return absl::Span(reinterpret_cast(v->data()), + v->size() * sizeof(T)); } // Turns an immutable slice of type T into an immutable slice of bytes with the // same byte size. template -tensorflow::gtl::ArraySlice CastToByteSlice( - tensorflow::gtl::ArraySlice slice) { - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size() * sizeof(T)); +absl::Span CastToByteSlice(absl::Span slice) { + return absl::Span(reinterpret_cast(slice.data()), + slice.size() * sizeof(T)); } // Casts a byte slice to a non-byte type T, checking that the original slice // length is a multiple of sizeof(T). template -tensorflow::gtl::ArraySlice CastByteSlice( - tensorflow::gtl::ArraySlice slice) { +absl::Span CastByteSlice(absl::Span slice) { CHECK_EQ(0, slice.size() % sizeof(T)); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size() / sizeof(T)); + return absl::Span(reinterpret_cast(slice.data()), + slice.size() / sizeof(T)); } // Convenience function to force a vector to convert to an immutable slice. template -tensorflow::gtl::ArraySlice AsSlice(const std::vector& v) { - return tensorflow::gtl::ArraySlice(v); +absl::Span AsSlice(const std::vector& v) { + return absl::Span(v); } -// Converts a mutable vector pointer into a MutableArraySlice of the same +// Converts a mutable vector pointer into a Span of the same // type. template -tensorflow::gtl::MutableArraySlice AsMutableSlice(std::vector* v) { - return tensorflow::gtl::MutableArraySlice(v->data(), v->size()); +absl::Span AsMutableSlice(std::vector* v) { + return absl::Span(v->data(), v->size()); } // xla::int64 is not the same type as tensorflow::protobuf_int64 in open-source. // Wrapper function that gives an int64 array slice view of a repeated int64 // protobuf field. -static inline tensorflow::gtl::ArraySlice AsInt64Slice( +static inline absl::Span AsInt64Slice( const tensorflow::protobuf::RepeatedField& v) { - tensorflow::gtl::ArraySlice slice(v); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size()); + absl::Span slice(v); + return absl::Span(reinterpret_cast(slice.data()), + slice.size()); } // As above, but for uint64 types. -static inline tensorflow::gtl::ArraySlice AsUInt64Slice( +static inline absl::Span AsUInt64Slice( const tensorflow::protobuf::RepeatedField& v) { - tensorflow::gtl::ArraySlice slice(v); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size()); + absl::Span slice(v); + return absl::Span(reinterpret_cast(slice.data()), + slice.size()); } // Compares two containers for equality. Returns true iff the two containers @@ -172,7 +173,7 @@ template bool ContainersEqual(const Container1T& c1, std::initializer_list il) { - tensorflow::gtl::ArraySlice c2{il}; + absl::Span c2{il}; return ContainersEqual(c1, c2); } @@ -190,9 +191,9 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2, // source and destination. The source starting index is src_base, while the // destination one is dest_base. template -void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, - int64 dest_stride, tensorflow::gtl::ArraySlice src, - int64 src_base, int64 src_stride, int64 count) { +void StridedCopy(absl::Span dest, int64 dest_base, int64 dest_stride, + absl::Span src, int64 src_base, int64 src_stride, + int64 count) { for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) { dest[dest_base] = static_cast(src[src_base]); } @@ -201,46 +202,76 @@ void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, // Adds some context information to the error message in a // Status. This is useful as Statuses are // propagated upwards. -Status AddStatus(Status prior, tensorflow::StringPiece context); -Status AppendStatus(Status prior, tensorflow::StringPiece context); - -// Status error shorthands -- printfs the arguments to be -// used as an error message and returns a status in the canonical -// error space. -Status InvalidArgument(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Unimplemented(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status InternalError(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status FailedPrecondition(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Cancelled(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); - -// Passed-varargs variant of the InvalidArgument factory above. -Status InvalidArgumentV(const char* format, va_list args); +Status AddStatus(Status prior, absl::string_view context); +Status AppendStatus(Status prior, absl::string_view context); + +// Status error shorthands -- StrFormat's the arguments to be used as an error +// message and returns a status in the canonical error space. +template +Status InvalidArgument(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::InvalidArgument(absl::StrFormat(format, args...))); +} +template +Status Unimplemented(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Unimplemented(absl::StrFormat(format, args...))); +} +template +Status InternalError(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Internal(absl::StrFormat(format, args...))); +} +template +Status FailedPrecondition(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::FailedPrecondition(absl::StrFormat(format, args...))); +} +template +Status Cancelled(const absl::FormatSpec& format, const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Cancelled(absl::StrFormat(format, args...))); +} +template +Status ResourceExhausted(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::ResourceExhausted(absl::StrFormat(format, args...))); +} +template +Status NotFound(const absl::FormatSpec& format, const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::NotFound(absl::StrFormat(format, args...))); +} +template +Status Unavailable(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Unavailable(absl::StrFormat(format, args...))); +} template Status InvalidArgumentStrCat(Args&&... concat) { - return InvalidArgument( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return InvalidArgument("%s", absl::StrCat(std::forward(concat)...)); } template Status UnimplementedStrCat(Args&&... concat) { - return Unimplemented( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return Unimplemented("%s", absl::StrCat(std::forward(concat)...)); } template Status InternalErrorStrCat(Args&&... concat) { - return InternalError( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return InternalError("%s", absl::StrCat(std::forward(concat)...)); } template Status ResourceExhaustedStrCat(Args&&... concat) { - return ResourceExhausted( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return ResourceExhausted("%s", absl::StrCat(std::forward(concat)...)); } // Splits the lines of the original, replaces leading whitespace with the prefix @@ -249,11 +280,10 @@ Status ResourceExhaustedStrCat(Args&&... concat) { // // Note: even different amounts of leading whitespace on different lines will be // uniformly replaced with "indentation". -string Reindent(tensorflow::StringPiece original, - tensorflow::StringPiece indentation); +string Reindent(absl::string_view original, absl::string_view indentation); // Checks whether permutation is a permutation of the [0, rank) integer range. -bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); +bool IsPermutation(absl::Span permutation, int64 rank); // Applies `permutation` on `input` and returns the permuted array. // For each i, output[permutation[i]] = input[i]. @@ -261,10 +291,11 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); // Precondition: // 1. `permutation` is a permutation of 0..permutation.size()-1. // 2. permutation.size() == input.size(). -template